Python: most efficient way to find the index of the maximum in a partially changed array

Question:

I have a complex-valued array of about 750000 elements for which I repeatedly (say 10^6 or more times) update 1000 (or fewer) different elements. In the absolute-squared array I then need to find the index of the maximum. This is part of a larger code which takes ~700 seconds to run; typically ~75% of that (~550 s) is spent on finding the index of the maximum. Even though ndarray.argmax() is "blazingly fast" according to https://stackoverflow.com/a/26820109/5269892, running it repeatedly on an array of 750000 elements (even though only 1000 elements have been changed) just takes too much time.

Below is a minimal, complete example, in which I use random numbers and indices. You may not make assumptions about how the real-valued array 'b' changes after an update (i.e. the values may be larger, smaller or equal), except, if you must, that the array value at the index of the previous maximum ('b[imax]') will likely be smaller after an update.

I tried using sorted arrays into which only the updated values are inserted (in sorted order) at the correct places to maintain sorting, because then we know the maximum is always at index -1 and we do not have to recompute it. The minimal example below includes timings. Unfortunately, selecting the non-updated values and inserting the updated values takes too much time (all other steps combined would require only ~210 µs instead of the ~580 µs of ndarray.argmax()).

Context: This is part of an implementation of the deconvolution algorithm CLEAN (Hoegbom, 1974) in the efficient Clark (1980) variant. As I intend to implement the Sequence CLEAN algorithm (Bose+, 2002), where even more iterations are required, and may want to use larger input arrays, my question is:

Question: What is the fastest way to find the index of the maximum value in the updated array (without applying ndarray.argmax() to the whole array in each iteration)?

Minimal example code (run on python 3.7.6, numpy 1.21.2, scipy 1.6.0):

import numpy as np

# some array shapes ('nnu_use' and 'nm'), number of total values ('nvals'), number of selected values ('nsel'; here
# 'nsel' == 'nvals'; in general 'nsel' <= 'nvals') and number of values to be changed ('nchange')
nnu_use, nm = 10418//2 + 1, 144
nvals = nnu_use * nm
nsel = nvals
nchange = 1000

# fix random seed, generate random 2D 'Fourier transform' ('a', complex-valued), compute power ('b', real-valued), and
# two 2D arrays for indices of axes 0 and 1
np.random.seed(100)
a = np.random.rand(nsel) + 1j * np.random.rand(nsel)
b = a.real ** 2 + a.imag ** 2
inu_2d = np.tile(np.arange(nnu_use)[:,None], (1,nm))
im_2d = np.tile(np.arange(nm)[None,:], (nnu_use,1))

# select 'nsel' random indices and get 1D arrays of the selected 2D indices
isel = np.random.choice(nvals, nsel, replace=False)
inu_sel, im_sel = inu_2d.flatten()[isel], im_2d.flatten()[isel]

def do_update_iter(a, b):
    # find index of maximum, choose 'nchange' indices of which 'nchange - 1' are random and the remaining one is the
    # index of the maximum, generate random complex numbers, update 'a' and compute updated 'b'
    imax = b.argmax()
    ichange = np.concatenate(([imax],np.random.choice(nsel, nchange-1, replace=False)))
    a_change = np.random.rand(nchange) + 1j*np.random.rand(nchange)
    a[ichange] = a_change
    b[ichange] = a_change.real ** 2 + a_change.imag ** 2
    return a, b, ichange

# do an update iteration on 'a' and 'b'
a, b, ichange = do_update_iter(a, b)

# sort 'a', 'b', 'inu_sel' and 'im_sel'
i_sort = b.argsort()
a_sort, b_sort, inu_sort, im_sort = a[i_sort], b[i_sort], inu_sel[i_sort], im_sel[i_sort]

# do an update iteration on 'a_sort' and 'b_sort'
a_sort, b_sort, ichange = do_update_iter(a_sort, b_sort)
b_sort_copy = b_sort.copy()

ind = np.arange(nsel)

def binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange):
    # binary insertion as an idea to save computation time relative to repeated argmax over entire (large) arrays
    # find updated values for 'a_sort', compute updated values for 'b_sort'
    a_change = a_sort[ichange]
    b_change = a_change.real ** 2 + a_change.imag ** 2
    # sort the updated values for 'a_sort' and 'b_sort' as well as the corresponding indices
    i_sort = b_change.argsort()
    a_change_sort = a_change[i_sort]
    b_change_sort = b_change[i_sort]
    inu_change_sort = inu_sort[ichange][i_sort]
    im_change_sort = im_sort[ichange][i_sort]
    # find indices of the non-updated values, cut out those indices from 'a_sort', 'b_sort', 'inu_sort' and 'im_sort'
    ind_complement = np.delete(ind, ichange)
    a_complement = a_sort[ind_complement]
    b_complement = b_sort[ind_complement]
    inu_complement = inu_sort[ind_complement]
    im_complement = im_sort[ind_complement]
    # find indices where sorted updated elements would have to be inserted into the sorted non-updated arrays to keep
    # the merged arrays sorted and insert the elements at those indices
    i_insert = b_complement.searchsorted(b_change_sort)
    a_updated = np.insert(a_complement, i_insert, a_change_sort)
    b_updated = np.insert(b_complement, i_insert, b_change_sort)
    inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
    im_updated = np.insert(im_complement, i_insert, im_change_sort)

    return a_updated, b_updated, inu_updated, im_updated

# do the binary insertion
a_updated, b_updated, inu_updated, im_updated = binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)

# do all the steps of the binary insertion, just to have the variable names defined
a_change = a_sort[ichange]
b_change = a_change.real ** 2 + a_change.imag ** 2
i_sort = b_change.argsort()
a_change_sort = a_change[i_sort]
b_change_sort = b_change[i_sort]
inu_change_sort = inu_sort[ichange][i_sort]
im_change_sort = im_sort[ichange][i_sort]
ind_complement = np.delete(ind, ichange)
a_complement = a_sort[ind_complement]
b_complement = b_sort[ind_complement]
inu_complement = inu_sort[ind_complement]
im_complement = im_sort[ind_complement]
i_insert = b_complement.searchsorted(b_change_sort)
a_updated = np.insert(a_complement, i_insert, a_change_sort)
b_updated = np.insert(b_complement, i_insert, b_change_sort)
inu_updated = np.insert(inu_complement, i_insert, inu_change_sort)
im_updated = np.insert(im_complement, i_insert, im_change_sort)

# timings for argmax and for sorting
%timeit b.argmax()             # 579 µs ± 1.26 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit b_sort.argmax()        # 580 µs ± 810 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit np.sort(b)             # 70.2 ms ± 120 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit np.sort(b_sort)        # 25.2 ms ± 109 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit b_sort_copy.sort()     # 14 ms ± 78.2 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

# timings for binary insertion
%timeit binary_insert(a_sort, b_sort, inu_sort, im_sort, ichange)          # 33.7 ms ± 208 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit a_change = a_sort[ichange]                                         # 4.28 µs ± 40.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change = a_change.real ** 2 + a_change.imag ** 2                 # 8.25 µs ± 127 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit i_sort = b_change.argsort()                                        # 35.6 µs ± 529 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_change_sort = a_change[i_sort]                                   # 4.2 µs ± 62.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit b_change_sort = b_change[i_sort]                                   # 2.05 µs ± 47 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit inu_change_sort = inu_sort[ichange][i_sort]                        # 4.47 µs ± 38 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit im_change_sort = im_sort[ichange][i_sort]                          # 4.51 µs ± 48.5 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
%timeit ind_complement = np.delete(ind, ichange)                           # 1.38 ms ± 25.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit a_complement = a_sort[ind_complement]                              # 3.52 ms ± 31.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_complement = b_sort[ind_complement]                              # 1.44 ms ± 256 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_complement = inu_sort[ind_complement]                          # 1.36 ms ± 6.61 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_complement = im_sort[ind_complement]                            # 1.31 ms ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit i_insert = b_complement.searchsorted(b_change_sort)                # 148 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit a_updated = np.insert(a_complement, i_insert, a_change_sort)       # 3.08 ms ± 111 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit b_updated = np.insert(b_complement, i_insert, b_change_sort)       # 1.37 ms ± 16.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit inu_updated = np.insert(inu_complement, i_insert, inu_change_sort) # 1.41 ms ± 28 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit im_updated = np.insert(im_complement, i_insert, im_change_sort)    # 1.52 ms ± 173 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

Update: As suggested below by @Jérôme Richard, a fast way to repeatedly find the index of the maximum in a partially updated array is to split the array into chunks and pre-compute the maxima of the chunks. In each iteration, one then re-computes the maxima of only the (at most) nchange updated chunks, takes the argmax over the chunk maxima to get the chunk index, and finally takes the argmax within that chunk.

I copied the code from @Jérôme Richard’s answer. In practice, his solution, when run on my system, results in a speed-up of about 7.3, requiring 46.6 + 33 = 79.6 µs instead of 580 µs for b.argmax().

import numba as nb

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk(b):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    max_per_chunk = np.empty(b.size // 32)

    for chunk_idx in nb.prange(b.size//32):
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks(b, max_per_chunk):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    return offset + np.argmax(b[offset:offset+32])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk(b, max_per_chunk, ichange):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    for idx in ichange:
        chunk_idx = idx // 32
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

b = np.random.rand(nsel)
max_per_chunk = precompute_max_per_chunk(b)
a, b, ichange = do_update_iter(a, b)
argmax_from_chunks(b, max_per_chunk)
update_max_per_chunk(b, max_per_chunk, ichange)

%timeit max_per_chunk = precompute_max_per_chunk(b)     # 77.3 µs ± 17.2 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks(b, max_per_chunk)            # 46.6 µs ± 11.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk(b, max_per_chunk, ichange) # 33 µs ± 40.6 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Update 2: I now modified @Jérôme Richard’s solution to work with arrays b whose size is not an integer multiple of the chunk size. In addition, the code only scans all values of a chunk if an updated value is smaller than the previous chunk maximum; otherwise it directly sets the updated value as the new chunk maximum. The if-checks should cost little compared to the time saved when the updated value is larger than the previous maximum. In my code this case becomes more and more likely as the iterations progress (the updated values get closer and closer to noise, i.e. random values). In practice, for random numbers, the execution time of update_max_per_chunk() is reduced a bit further, from ~33 µs to ~27 µs. The code and new timings are:

import math

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk_bp(b):
    nchunks = math.ceil(b.size/32)
    imod = b.size % 32
    max_per_chunk = np.empty(nchunks)
    
    for chunk_idx in nb.prange(nchunks):
        offset = chunk_idx * 32
        maxi = b[offset]
        if (chunk_idx != (nchunks - 1)) or (not imod):
            iend = 32
        else:
            iend = imod
        for j in range(1, iend):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks_bp(b, max_per_chunk):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    if (chunk_idx != (nchunks - 1)) or (not imod):
        return offset + np.argmax(b[offset:offset+32])
    else:
        return offset + np.argmax(b[offset:offset+imod])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk_bp(b, max_per_chunk, ichange):
    nchunks = max_per_chunk.size
    imod = b.size % 32
    for idx in ichange:
        chunk_idx = idx // 32
        if b[idx] < max_per_chunk[chunk_idx]:
            offset = chunk_idx * 32
            if (chunk_idx != (nchunks - 1)) or (not imod):
                iend = 32
            else:
                iend = imod
            maxi = b[offset]
            for j in range(1, iend):
                maxi = max(b[offset + j], maxi)
            max_per_chunk[chunk_idx] = maxi
        else:
            max_per_chunk[chunk_idx] = b[idx]

%timeit max_per_chunk = precompute_max_per_chunk_bp(b)     # 74.6 µs ± 29.8 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit argmax_from_chunks_bp(b, max_per_chunk)            # 46.6 µs ± 9.92 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit update_max_per_chunk_bp(b, max_per_chunk, ichange) # 26.5 µs ± 19.1 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Asked By: bproxauf

Answers:

ndarray.argmax() is "blazingly fast" according to https://stackoverflow.com/a/26820109/5269892

Argmax is not optimal since it does not manage to saturate the RAM bandwidth on my machine (which would be possible), but it is very good: it saturates ~40% of the total RAM throughput in your case, and about 65%-70% when run sequentially on my machine (one core cannot saturate the RAM on most machines). Most machines have a lower throughput, so np.argmax should be even closer to optimal on them.

Finding the maximum value using multiple threads can help to reach the optimum, but given the current performance of the function, one should not expect a speed-up greater than 2 on most PCs (more on computing servers).

What is the fastest way to find the index of the maximum value in the updated array

Whatever the computation done, reading the whole array from memory takes at least b.size * 8 / RAM_throughput seconds. With very good 2-channel DDR4 RAM, the optimal time is about ~125 µs, while the best 1-channel DDR4 RAM achieves ~225 µs. If the array is also written in place, the optimal time is twice as large, and if a new array is created (out-of-place computation), it is 3 times larger on x86-64 platforms. In fact, it is even worse for the latter because of the big overheads of the OS virtual memory.
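
As a quick back-of-the-envelope check of this lower bound (an illustrative sketch; the throughput figures below are assumed round numbers, not measurements):

n = 750_000                                   # array elements (float64)
bytes_to_read = n * 8                         # ~6 MB per full pass over 'b'
for label, throughput in (("2-channel DDR4", 48e9), ("1-channel DDR4", 27e9)):
    print(label, round(bytes_to_read / throughput * 1e6), "us")
# prints roughly 125 us and 222 us, consistent with the estimates above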

What this means is that no out-of-place computation reading the whole array can beat np.argmax on a mainstream PC. This also explains why the sort solution is so slow: it creates many temporary arrays. Even a perfect sorted-array strategy would not be much faster than np.argmax here (because all items need to be moved in RAM in the worst case, and far more than half of them on average). In fact, the benefit of any in-place method writing the whole array is low (still on a mainstream PC): it would only be slightly faster than np.argmax. The only way to get a significant speed-up is not to operate on the whole array.

One efficient solution to this problem is to use a balanced binary search tree. Indeed, you can remove the k nodes from a tree containing n nodes in O(k log n) time, and insert the updated values in the same time. This is much better than an O(n) solution in your case because n ~= 750_000 and k ~= 1_000. Still, note that there is a hidden constant factor behind the complexity, and binary search trees may not be so fast in practice, especially if they are not very optimized. Also note that it is better to update the tree values than to delete nodes and insert new ones. A pure-Python implementation will hardly be fast enough in this case (and will take a lot of memory). Only Cython or a native solution can be fast (e.g. C/C++, or any natively implemented Python module, but I could not find one that is fast).
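
For illustration only, here is a minimal sketch of this incremental idea using the third-party sortedcontainers package (assumed to be installed via pip install sortedcontainers); it keeps (value, index) pairs ordered with roughly logarithmic updates. It is not the native implementation recommended above, and the per-update Python overhead means it may well not beat np.argmax in practice:

from sortedcontainers import SortedList

# keep (value, index) pairs sorted by value; the maximum is the last entry
tree = SortedList((v, i) for i, v in enumerate(b))

def update_and_argmax(b, tree, ichange, new_vals):
    for i, v in zip(ichange, new_vals):
        tree.remove((b[i], i))   # remove the stale entry (O(log n)-ish)
        b[i] = v
        tree.add((v, i))         # insert the updated entry (O(log n)-ish)
    return tree[-1][1]           # index of the current maximum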

Another alternative is a static n-ary tree-based partial-maximums data structure. It consists in splitting the array into chunks and pre-computing the maximum of each chunk first. When values are updated (and assuming the number of items is constant), you need to (1) recompute the maximum of each updated chunk and (2) compute the maximum over the per-chunk maxima to get the global maximum. This solution also requires a (semi-)native implementation to be fast, since Numpy introduces significant overheads when updating the per-chunk maximum values (it is not well optimized for such a case), but one should certainly see a speed-up. Numba and Cython can be used to do that, for example. The size of the chunks needs to be chosen carefully; in your case, something between 16 and 32 should give you a huge speed-up.

With chunks of size 32, updating the per-chunk maxima requires reading at most 32*k = 32_000 values (and writing up to 1000 values), which is far less than 750_000. Computing the global maximum then requires scanning only the n/32 ~= 23_400 per-chunk maxima, which is still relatively small. I expect this to be at least 5 times faster with an optimized implementation, and probably even >10 times faster in practice, especially with a parallel implementation. This is certainly the best solution (without additional assumptions).


Implementation in Numba

Here is a (barely tested) Numba implementation:

import numpy as np
import numba as nb

@nb.njit('(f8[::1],)', parallel=True)
def precompute_max_per_chunk(b):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    max_per_chunk = np.empty(b.size // 32)

    for chunk_idx in nb.prange(b.size//32):
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

    return max_per_chunk

@nb.njit('(f8[::1], f8[::1])')
def argmax_from_chunks(b, max_per_chunk):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    chunk_idx = np.argmax(max_per_chunk)
    offset = chunk_idx * 32
    return offset + np.argmax(b[offset:offset+32])

@nb.njit('(f8[::1], f8[::1], i8[::1])')
def update_max_per_chunk(b, max_per_chunk, ichange):
    # Required for this simplified version to work and be simple
    assert b.size % 32 == 0
    assert max_per_chunk.size == b.size // 32

    for idx in ichange:
        chunk_idx = idx // 32
        offset = chunk_idx * 32
        maxi = b[offset]
        for j in range(1, 32):
            maxi = max(b[offset + j], maxi)
        max_per_chunk[chunk_idx] = maxi

Here is an example of how to use it and timings on my (6-core) machine:

# Precomputation (306 µs)
max_per_chunk = precompute_max_per_chunk(b)

# Computation of the global max from the chunks (22.3 µs)
argmax_from_chunks(b, max_per_chunk)

# Update of the chunks (25.2 µs)
update_max_per_chunk(b, max_per_chunk, ichange)

# Initial best implementation: 357 µs
np.argmax(b)

As you can see, it is pretty fast. An update should take 22.3 + 25.2 = 47.5 µs, while the naive Numpy implementation takes 357 µs, so the Numba implementation is about 7.5 times faster! I think it can be optimized a bit further, but it is not simple. Note that the update is sequential and the pre-computation is parallel. Fun fact: the pre-computation followed by a call to argmax_from_chunks is faster than np.argmax, thanks to the use of multiple threads!


Further improvements

argmax_from_chunks can be improved thanks to SIMD instructions. Indeed, the current implementation generates the scalar maxsd/vmaxsd instruction on x86-64 machines, which is sub-optimal. The operation can be vectorized by using a tile-based argmax computing the maximum with a x4-unrolled loop (possibly even x8 on recent 512-bit-wide SIMD machines). On my processor supporting the AVX instruction set, experiments show that Numba can generate code running in 6-7 µs (about 4 times faster). That being said, this is tricky to implement and the resulting function is a bit ugly.
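
For illustration, here is a rough sketch of the x4-unrolling idea applied to the reduction over the per-chunk maxima (the function name argmax_unrolled4 and its structure are illustrative assumptions, not the AVX-tuned code mentioned above; a real implementation would also vectorize the 32-element in-chunk argmax):

@nb.njit('(f8[::1],)')
def argmax_unrolled4(max_per_chunk):
    n = max_per_chunk.size
    # four independent running maxima make the reduction easier to vectorize
    m0 = -np.inf
    m1 = -np.inf
    m2 = -np.inf
    m3 = -np.inf
    for i in range(0, n - (n % 4), 4):
        m0 = max(m0, max_per_chunk[i])
        m1 = max(m1, max_per_chunk[i + 1])
        m2 = max(m2, max_per_chunk[i + 2])
        m3 = max(m3, max_per_chunk[i + 3])
    best = max(max(m0, m1), max(m2, m3))
    for i in range(n - (n % 4), n):   # scalar tail for the remaining elements
        best = max(best, max_per_chunk[i])
    # second (cheap) pass to recover the index of the maximum value
    for i in range(n):
        if max_per_chunk[i] == best:
            return i
    return -1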

The same method can also be used to speed up update_max_per_chunk, which is unfortunately not vectorized by default either. I also expect a ~4x speed-up on a recent x86-64 machine. However, Numba generates a very inefficient vectorization in many cases (it tries to vectorize the outer loop instead of the inner one). As a result, my best attempt with Numba reached 16.5 µs.

In theory, the whole update could be made about 4 times faster on a mainstream x86-64 machine; in practice, a version at least 2 times faster should be achievable!

Answered By: Jérôme Richard