Efficient thresholding filter of an array with numpy

Question:

I need to filter an array to remove the elements that are lower than a certain threshold. My current code is like this:

import numpy

threshold = 5
a = numpy.array(range(10)) # testing data
b = numpy.array(list(filter(lambda x: x >= threshold, a)))

The problem is that this builds a temporary list and calls a lambda function for every element through filter, which is slow.

As this is quite a simple operation, maybe there is a numpy function that does it in an efficient way, but I’ve been unable to find it.

Another way I thought of is to sort the array, find the index of the threshold and return a slice from that index onwards. But even if this were faster for small inputs (where the difference would not be noticeable anyway), it’s definitely asymptotically less efficient as the input size grows.
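For reference, the sort-and-slice idea could look roughly like the sketch below. It assumes np.sort plus np.searchsorted as one possible way to find the cut point; the sample values are made up for illustration. Note that the result comes back in sorted order rather than in the original order, and that the sort costs O(n log n).

```python
import numpy as np

threshold = 5
a = np.array([7, 2, 9, 5, 1, 8])  # illustrative data

# Sort a copy, then find the first position whose value is >= threshold.
s = np.sort(a)
start = np.searchsorted(s, threshold, side="left")

# Everything from that position onwards is >= threshold (in sorted order).
b = s[start:]
```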

Update: I took some measurements too, and sorting + slicing was still twice as fast as the pure Python filter when the input had 100,000,000 entries.

r = numpy.random.uniform(0, 1, 100000000)

%timeit test1(r) # filter
# 1 loops, best of 3: 21.3 s per loop

%timeit test2(r) # sort and slice
# 1 loops, best of 3: 11.1 s per loop

%timeit test3(r) # boolean indexing
# 1 loops, best of 3: 1.26 s per loop
Asked By: fortran


Answers:

b = a[a >= threshold] should do it.

I tested as follows:

import numpy as np, datetime
# 1,000,000 zeros followed by 1,000,000 ones
lrg = np.arange(2).reshape((2,-1)).repeat(1000000,-1).flatten()

# boolean indexing
t0 = datetime.datetime.now()
flt = lrg[lrg==0]
print(datetime.datetime.now() - t0)

# pure Python filter with a lambda
t0 = datetime.datetime.now()
flt = np.array(list(filter(lambda x: x==0, lrg)))
print(datetime.datetime.now() - t0)

I got

$ python test.py
0:00:00.028000
0:00:02.461000

http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays
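
To make the mechanics a bit more explicit, here is a minimal sketch that splits the one-liner into two steps. The name mask is only for illustration; the comparison produces a boolean array, and indexing with it returns a new array with the selected elements.

```python
import numpy as np

threshold = 5
a = np.arange(10)  # same test data as in the question

# Element-wise comparison: a boolean array with the same shape as a.
mask = a >= threshold

# Boolean indexing: a new array with only the elements where mask is True.
b = a[mask]
```

The original array a is not modified, and the relative order of the kept elements is preserved.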

Answered By: yosukesabai

You can also use np.where to get the indices where the condition is True and use advanced indexing.

import numpy as np
b = a[np.where(a >= threshold)]
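
As a small sketch of what happens here (the sample values and the names idx, b_where and b_mask are illustrative): with a single argument, np.where returns a tuple containing one index array per dimension, and indexing with that tuple selects the same elements as the boolean mask.

```python
import numpy as np

threshold = 5
a = np.array([3, 7, 2, 6, 1])  # illustrative data

# With only a condition, np.where returns a tuple of index arrays.
idx = np.where(a >= threshold)

b_where = a[idx]            # advanced (integer) indexing
b_mask = a[a >= threshold]  # boolean indexing, for comparison
```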

One useful feature of np.where is that you can use it to replace values (e.g. replace the values where the threshold is not met). While a[a < 5] = 0 modifies a in place, np.where returns a new array of the same shape with only some values (potentially) changed.

a = np.array([3, 7, 2, 6, 1])
b = np.where(a >= 5, a, 0)       # array([0, 7, 0, 6, 0])
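
The difference between the two approaches can be sketched side by side. The copy c is introduced only for illustration, so that the in-place assignment does not overwrite a.

```python
import numpy as np

a = np.array([3, 7, 2, 6, 1])

# Three-argument np.where builds a new array; a itself is left untouched.
b = np.where(a >= 5, a, 0)

# Masked assignment changes the array it is applied to, so work on a copy here.
c = a.copy()
c[c < 5] = 0
```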

Indexing with np.where is also very competitive with boolean indexing in terms of performance.

a, threshold = np.random.uniform(0,1,100000000), 0.5

%timeit a[a >= threshold]
# 1.22 s ± 92.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit a[np.where(a >= threshold)]
# 1.34 s ± 258 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Answered By: cottontail