Finding the index of a numpy array in a list
Question:
import numpy as np
foo = [1, "hello", np.array([[1,2,3]]) ]
I would expect
foo.index( np.array([[1,2,3]]) )
to return
2
but instead I get
ValueError: The truth value of an array with more than one element is
ambiguous. Use a.any() or a.all()
Is there anything better than my current solution? It seems inefficient.
def find_index_of_array(list, array):
    for i in range(len(list)):
        if np.all(list[i]==array):
            return i
find_index_of_array(foo, np.array([[1,2,3]]) )
# 2
Answers:
The error occurs because numpy’s ndarray overrides == to return an array rather than a boolean.
AFAIK, there is no simple solution here. The following will work so long as the np.all(val == array) bit works.
next((i for i, val in enumerate(lst) if np.all(val == array)), -1)
Whether that bit works or not depends critically on what the other elements in the list are and whether they can be compared with numpy arrays.
For performance, you might want to process only the NumPy arrays in the input list. So, we could type-check before going into the loop and index into the elements that are arrays.
Thus, an implementation would be –
def find_index_of_array_v2(list1, array1):
    idx = np.nonzero([type(i).__module__ == np.__name__ for i in list1])[0]
    for i in idx:
        if np.all(list1[i]==array1):
            return i
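The type check and the comparison can also be folded into a single pass. The following is a minimal sketch, not part of the answers above: the helper name index_of_array is hypothetical, and it assumes that isinstance(item, np.ndarray) and np.array_equal are acceptable substitutes for the module-name test and np.all(... == ...). It returns -1 when nothing matches, like the next(...) one-liner.

```python
import numpy as np

def index_of_array(items, target):
    """Return the index of the first ndarray in items equal to target, or -1."""
    for i, item in enumerate(items):
        # Skip non-arrays so ints and strings are never compared with target.
        if isinstance(item, np.ndarray) and np.array_equal(item, target):
            return i
    return -1

foo = [1, "hello", np.array([[1, 2, 3]])]
found = index_of_array(foo, np.array([[1, 2, 3]]))    # 2
missing = index_of_array(foo, np.array([[4, 5, 6]]))  # -1
```

Because np.array_equal also compares shapes, this variant does not match arrays that are only equal after broadcasting.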
How about this one?
arr = np.array([[1,2,3]])
foo = np.array([1, 'hello', arr], dtype=object)  # np.object was removed in newer NumPy releases
# if foo array is of heterogeneous elements (str, int, array)
[idx for idx, el in enumerate(foo) if type(el) == type(arr)]
# if foo array has only numpy arrays in it
[idx for idx, el in enumerate(foo) if np.array_equal(el, arr)]
Output:
[2]
Note: This will also work even if foo is a list. I just put it as a numpy array here.
The issue here (you probably know already but just to repeat it) is that list.index works along the lines of:
for idx, item in enumerate(your_list):
    if item == wanted_item:
        return idx
The line if item == wanted_item is the problem, because it implicitly converts the result of item == wanted_item to a boolean. A numpy.ndarray with more than one element raises this ValueError when converted:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
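The failure can be reproduced without any list at all. This is a small illustrative sketch; the variable names are chosen here and are not from the answer.

```python
import numpy as np

a = np.array([[1, 2, 3]])

# == on arrays is elementwise, so the result is itself an array.
comparison = (a == a)

try:
    # This is the implicit conversion that an `if` statement performs.
    bool(comparison)
except ValueError as exc:
    message = str(exc)
```

Here comparison has the same shape as a, and message holds the "truth value ... is ambiguous" text quoted above.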
Solution 1: adapter (thin wrapper) class
I generally use a thin wrapper (adapter) around numpy.ndarray whenever I need to use python functions like list.index:
class ArrayWrapper(object):
    __slots__ = ["_array"]  # minimizes the memory footprint of the class.
    def __init__(self, array):
        self._array = array
    def __eq__(self, other_array):
        # array_equal also makes sure the shape is identical!
        # If you don't mind broadcasting you can also use
        # np.all(self._array == other_array)
        return np.array_equal(self._array, other_array)
    def __array__(self):
        # This makes sure that `np.asarray` works and quite fast.
        return self._array
    def __repr__(self):
        return repr(self._array)
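The comment in __eq__ about shapes and broadcasting is easy to check. A short sketch, with variable names chosen here for illustration:

```python
import numpy as np

square = np.ones((3, 3))
row = np.ones(3)

# The elementwise comparison broadcasts row against every row of square.
broadcast_match = bool(np.all(square == row))     # True

# array_equal requires identical shapes before comparing values.
strict_match = bool(np.array_equal(square, row))  # False
```

With np.all(... == ...) in __eq__, a (3,) array of ones would therefore also "match" a (3, 3) array of ones in list.index.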
These thin wrappers are more expensive than manually using an enumerate loop or comprehension, but you don’t have to re-implement the python functions. Assuming the list contains only numpy arrays (otherwise you need to do some if ... else ... checking):
>>> list_of_arrays = [np.ones((3, 3)), np.ones((3)), np.ones((3, 3)) * 2, np.ones((3))]
>>> list_of_wrapped_arrays = [ArrayWrapper(arr) for arr in list_of_arrays]
After this step you can use all your python functions on this list:
>>> list_of_wrapped_arrays.index(np.ones((3,3)))
0
>>> list_of_wrapped_arrays.index(np.ones((3)))
1
These wrappers are not numpy arrays anymore, but because they are thin the extra list is quite small. So depending on your needs you could keep both the wrapped list and the original list and choose which one to operate on. For example, you can now also count identical arrays with list.count:
>>> list_of_wrapped_arrays.count(np.ones((3)))
2
or list.remove:
>>> list_of_wrapped_arrays.remove(np.ones((3)))
>>> list_of_wrapped_arrays
[array([[ 1., 1., 1.],
        [ 1., 1., 1.],
        [ 1., 1., 1.]]),
 array([[ 2., 2., 2.],
        [ 2., 2., 2.],
        [ 2., 2., 2.]]),
 array([ 1., 1., 1.])]
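For a list that mixes arrays with other objects, the if ... else ... checking mentioned above could look roughly like this. It is a sketch under assumptions: WrappedArray and wrap are hypothetical names, the class is a cut-down variant of the adapter, and the search target is wrapped as well so that plain elements such as 1 or "hello" are never compared with a bare ndarray.

```python
import numpy as np

class WrappedArray:
    """Hypothetical cut-down adapter: equality means same shape and values."""
    __slots__ = ["_array"]

    def __init__(self, array):
        self._array = array

    def __eq__(self, other):
        # Unwrap the other operand if it is a wrapper too.
        if isinstance(other, WrappedArray):
            other = other._array
        return np.array_equal(self._array, other)

def wrap(item):
    # Only arrays get wrapped; everything else is passed through unchanged.
    return WrappedArray(item) if isinstance(item, np.ndarray) else item

foo = [1, "hello", np.array([[1, 2, 3]])]
wrapped = [wrap(item) for item in foo]
position = wrapped.index(wrap(np.array([[1, 2, 3]])))  # 2
```

Wrapping the target matters: with a bare array as the argument, the comparison against the leading 1 would again produce an elementwise result and raise the ValueError.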
Solution 2: subclass and ndarray.view
This approach uses explicit subclasses of numpy.ndarray. It has the advantage that you get all builtin array functionality and only modify the requested operation (which would be __eq__):
class ArrayWrapper(np.ndarray):
    def __eq__(self, other_array):
        return np.array_equal(self, other_array)
>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1
Again you get most list methods this way: list.remove and list.count, besides list.index.
However, this approach may yield subtle behaviour if some operation implicitly uses __eq__. You can always re-interpret it as a plain numpy array by using np.asarray or .view(np.ndarray):
>>> view_list[1]
ArrayWrapper([ 2., 2., 2.])
>>> view_list[1].view(np.ndarray)
array([ 2., 2., 2.])
>>> np.asarray(view_list[1])
array([ 2., 2., 2.])
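One instance of that subtle behaviour: code that expects == to produce an elementwise mask gets a single truth value from the view instead. A minimal sketch, using the hypothetical class name EqByValue to avoid clashing with the classes above:

```python
import numpy as np

class EqByValue(np.ndarray):
    def __eq__(self, other):
        # Whole-array comparison instead of the elementwise default.
        return np.array_equal(self, other)

plain = np.array([1.0, 2.0, 3.0])
viewed = plain.view(EqByValue)
other = np.array([1.0, 0.0, 3.0])

elementwise = plain == other   # boolean array, one entry per element
whole = viewed == other        # a single truth value
```

Here elementwise is a boolean mask that can be used for indexing, while whole is just False, so expressions such as arr[arr == value] no longer select elements when arr is a view of this kind.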
Alternative: Overriding __bool__ (or __nonzero__ for python 2)
Instead of fixing the problem in the __eq__ method you could also override __bool__ or __nonzero__:
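class ArrayWrapper(np.ndarray):
    # This could also be done in the adapter solution.
    def __bool__(self):
        # np.asarray drops the subclass, so the result of np.all is a
        # plain numpy value and bool() does not call this method again.
        return bool(np.all(np.asarray(self)))
    __nonzero__ = __bool__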
Again this makes list.index work as intended:
>>> your_list = [np.ones(3), np.ones(3)*2, np.ones(3)*3, np.ones(3)*4]
>>> view_list = [arr.view(ArrayWrapper) for arr in your_list]
>>> view_list.index(np.array([2,2,2]))
1
But this will definitely modify more behaviour! For example:
>>> if np.array([1,2,3]).view(ArrayWrapper):
... print('that was previously impossible!')
that was previously impossible!
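The changed truthiness cuts both ways: an array containing a zero becomes falsy. The sketch below uses the hypothetical class name TruthyArray and routes the reduction through np.asarray, which is an addition made here so that bool() is applied to a plain numpy value.

```python
import numpy as np

class TruthyArray(np.ndarray):
    def __bool__(self):
        # True only if every element is truthy.
        return bool(np.all(np.asarray(self)))

all_nonzero = np.array([1, 2, 3]).view(TruthyArray)
has_zero = np.array([1, 0, 3]).view(TruthyArray)

truthy = bool(all_nonzero)  # True
falsy = bool(has_zero)      # False
```

Any code that tests such a view in an if or while statement now silently gets an "all elements nonzero" check rather than an error.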
This should do the job:
[i for i,j in enumerate(foo) if j.__class__.__name__=='ndarray']
[2]
You can use a view to override the equality method:
import numpy as np
class Vector(np.ndarray):
    def __eq__(self, other: np.ndarray) -> bool:
        # Pass self directly; np.array_equal converts both operands itself.
        return np.array_equal(self, other)
data = list(np.random.random((100, 3)))
element = data[3]
print(data.index(element.view(Vector)))  # prints 3
print(element.view(Vector) in data)  # prints True