MATLAB-style find() function in Python

Question:

In MATLAB it is easy to find the indices of values that meet a particular condition:

>> a = [1,2,3,1,2,3,1,2,3];
>> find(a > 2)     % find the indices where this condition is true
[3, 6, 9]          % (MATLAB uses 1-based indexing)
>> a(find(a > 2))  % get the values at those locations
[3, 3, 3]

What would be the best way to do this in Python?

So far, I have come up with the following. To just get the values:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> [val for val in a if val > 2]
[3, 3, 3]

But if I want the index of each of those values it’s a bit more complicated:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> inds = [i for (i, val) in enumerate(a) if val > 2]
>>> inds
[2, 5, 8]
>>> [val for (i, val) in enumerate(a) if i in inds]
[3, 3, 3]

Is there a better way to do this in Python, especially for arbitrary conditions (not just ‘val > 2’)?

I found functions equivalent to MATLAB ‘find’ in NumPy but I currently do not have access to those libraries.

Asked By: user344226


Answers:

In NumPy you have where():

>>> import numpy as np
>>> x = np.random.randint(0, 20, 10)
>>> x
array([14, 13,  1, 15,  8,  0, 17, 11, 19, 13])
>>> np.where(x > 10)
(array([0, 1, 3, 6, 7, 8, 9], dtype=int64),)
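
Called with a single argument, np.where() returns a tuple holding one index array per dimension, which is why the output above is wrapped in parentheses. Here is a short sketch with a fixed array (the random one above differs on every run) showing how to unwrap the indices and how a boolean mask gives the values directly. The names inds and vals are only illustrative:

```python
import numpy as np

x = np.array([14, 13, 1, 15, 8, 0, 17, 11, 19, 13])

inds = np.where(x > 10)[0]   # unwrap the 1-tuple to get a plain index array
vals = x[x > 10]             # a boolean mask selects the matching values

print(inds)  # [0 1 3 6 7 8 9]
print(vals)  # [14 13 15 17 11 19 13]
```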
Answered By: joaquin

You can make a function that takes a callable parameter which will be used in the condition part of your list comprehension. Then you can use a lambda or other function object to pass your arbitrary condition:

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)

>>> inds
[2, 5, 8]

It’s a little closer to your MATLAB example, without having to load all of NumPy.
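
With the indices in hand, the values can be read back by direct indexing instead of a second pass through enumerate(). This is a minimal sketch that reuses the indices() helper defined above; the name vals is only illustrative:

```python
def indices(a, func):
    """Return the positions in a whose values satisfy func."""
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)   # positions of the matching values
vals = [a[i] for i in inds]          # look each value up by its position

print(inds)  # [2, 5, 8]
print(vals)  # [3, 3, 3]
```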

Answered By: John

To get values with arbitrary conditions, you could use filter() with a lambda function:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> filter(lambda x: x > 2, a)
[3, 3, 3]

One possible way to get the indices would be to use enumerate() to build a tuple with both indices and values, and then filter that:

>>> a = [1,2,3,1,2,3,1,2,3]
>>> aind = tuple(enumerate(a))
>>> print aind
((0, 1), (1, 2), (2, 3), (3, 1), (4, 2), (5, 3), (6, 1), (7, 2), (8, 3))
>>> filter(lambda x: x[1] > 2, aind)
((2, 3), (5, 3), (8, 3))
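
The sessions above are Python 2, where filter() returns a list (or a tuple when it is given a tuple). In Python 3, filter() returns a lazy iterator, so the result has to be wrapped in list() before it can be displayed or indexed. Here is a sketch of the same idea for Python 3; the names vals, pairs and inds are illustrative:

```python
a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

# filter() is lazy in Python 3, so materialize it with list()
vals = list(filter(lambda x: x > 2, a))

# Keep the (index, value) pairs whose value passes the test
pairs = list(filter(lambda pair: pair[1] > 2, enumerate(a)))

# Pull just the indices back out of the pairs
inds = [i for (i, _) in pairs]

print(vals)   # [3, 3, 3]
print(pairs)  # [(2, 3), (5, 3), (8, 3)]
print(inds)   # [2, 5, 8]
```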
Answered By: Blair

Why not just use this:

[i for i in range(len(a)) if a[i] > 2]

or for arbitrary conditions, define a function f for your condition and do:

[i for i in range(len(a)) if f(a[i])]
Answered By: JasonFruit

I’ve been trying to figure out a fast way to do this exact thing, and here is what I stumbled upon (uses numpy for its fast vector comparison):

import numpy
a_bool = numpy.array(a) > 2
inds = [i for (i, val) in enumerate(a_bool) if val]

It turns out that this is much faster than:

inds = [i for (i, val) in enumerate(a) if val > 2]

It seems that Python is faster at comparison when done in a numpy array, and/or faster at doing list comprehensions when just checking truth rather than comparison.

Edit:

I was revisiting my code and came across a possibly less memory-intensive, slightly faster, and very concise way of doing this in one line (with numpy imported as np, and a converted to an array so the comparison is element-wise):

inds = np.arange(len(a))[np.array(a) > 2]
Answered By: Nate

Or use numpy’s nonzero function:

import numpy as np
a    = np.array([1,2,3,4,5])
inds = np.nonzero(a>2)
a[inds] 
array([3, 4, 5])
Answered By: vincentv

MATLAB’s find takes an optional second argument, n, which limits the result to the first n matches. John’s code accounts for the first argument but not the second. For instance, to get only the first index where the condition is satisfied, the MATLAB call would be:

find(x>2,1)

Using John’s code, all you have to do is add an index such as [0] after the call to the indices function to pick out the match you’re looking for.

def indices(a, func):
    return [i for (i, val) in enumerate(a) if func(val)]

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

inds = indices(a, lambda x: x > 2)[0]  # [0] picks the first match, like find(x>2,1)

which returns 2, the index of the first value that exceeds 2.
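
In MATLAB, find(X, n) returns at most the first n matching indices, so the general Python counterpart is closer to a slice than to a single [0]. One way to sketch it, assuming a hypothetical helper named first_indices (not part of any library), is a generator cut short with itertools.islice, so the scan stops as soon as n matches are found:

```python
from itertools import islice

def first_indices(a, func, n):
    """Return at most the first n positions in a whose values satisfy func."""
    matches = (i for (i, val) in enumerate(a) if func(val))
    return list(islice(matches, n))

a = [1, 2, 3, 1, 2, 3, 1, 2, 3]

print(first_indices(a, lambda x: x > 2, 1))   # [2]
print(first_indices(a, lambda x: x > 2, 2))   # [2, 5]
print(first_indices(a, lambda x: x > 10, 1))  # [] (no match, and no IndexError)
```

Returning a list even for n = 1 means an empty result is simply [], whereas indexing with [0] raises IndexError when nothing matches.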

Answered By: Clayton Pipkin

The NumPy routine more commonly used for this is numpy.where(); called with just a condition, it works the same as numpy.nonzero().

import numpy
a    = numpy.array([1,2,3,4,5])
inds = numpy.where(a>2)

To get the values, you can either store the indices and index with them:

a[inds]

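or you can pass two more arrays to numpy.where(). The two extra arguments must be given together: the result takes its values from the first array where the condition is true and from the second array everywhere else:

b = numpy.array([11,22,33,44,55])
numpy.where(a>2, a, b)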
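
Here is a runnable sketch of the one-argument and three-argument forms, using the same a and b as above, with the output each one produces. The name mixed is only illustrative:

```python
import numpy

a = numpy.array([1, 2, 3, 4, 5])
b = numpy.array([11, 22, 33, 44, 55])

# Condition only: a tuple of index arrays, one per dimension
inds = numpy.where(a > 2)

# Condition plus two arrays: take from a where true, from b elsewhere
mixed = numpy.where(a > 2, a, b)

print(inds[0])   # [2 3 4]
print(a[inds])   # [3 4 5]
print(mixed)     # [11 22  3  4  5]
```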
Answered By: ryanjdillon

I think I may have found one quick and simple substitute.
BTW, I found the np.where() function not very satisfactory here, in the sense that for a 2-D array like the one below it also returns an annoying row of zeros (the row indices).

import numpy as np
import matplotlib.mlab as mlab
a = np.random.randn(1,5)
print a

>> [[ 1.36406736  1.45217257 -0.06896245  0.98429727 -0.59281957]]

idx = mlab.find(a<0)
print idx
type(idx)

>> [2 4]
>> np.ndarray
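
As far as I know, mlab.find() was deprecated and later removed from Matplotlib, so the snippet above is unlikely to run on current releases. np.flatnonzero() gives the same flat indices using NumPy alone. Here is a sketch with fixed values copied from the output above instead of random draws:

```python
import numpy as np

a = np.array([[1.36406736, 1.45217257, -0.06896245, 0.98429727, -0.59281957]])

idx = np.flatnonzero(a < 0)   # indices into the flattened array

print(idx)        # [2 4]
print(type(idx))  # <class 'numpy.ndarray'>
```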

Best,
Da

Answered By: DidasW