numpy: what is the logic of the argmin() and argmax() functions?

Question:

I cannot understand the output of argmax and argmin when they are used with the axis parameter. For example:

>>> a = np.array([[1,2,4,7], [9,88,6,45], [9,76,3,4]])
>>> a
array([[ 1,  2,  4,  7],
       [ 9, 88,  6, 45],
       [ 9, 76,  3,  4]])
>>> a.shape
(3, 4)
>>> a.size
12
>>> np.argmax(a)
5
>>> np.argmax(a,axis=0)
array([1, 1, 1, 1])
>>> np.argmax(a,axis=1)
array([3, 1, 1])
>>> np.argmin(a)
0
>>> np.argmin(a,axis=0)
array([0, 0, 2, 2])
>>> np.argmin(a,axis=1)
array([0, 2, 2])

As you can see, the maximum value is at the point (1,1) and the minimum is at the point (0,0). So by my logic, when I run:

  • np.argmin(a,axis=0) I expected array([0,0,0,0])
  • np.argmin(a,axis=1) I expected array([0,0,0])
  • np.argmax(a,axis=0) I expected array([1,1,1,1])
  • np.argmax(a,axis=1) I expected array([1,1,1])

What is wrong with my understanding of things?

Asked By: user4584333


Answers:

By adding the axis argument, NumPy looks at the rows and columns individually. When it’s not given, the array a is flattened into a single 1D array.

axis=0 means that the operation is performed down the columns of a 2D array a in turn.

For example, np.argmin(a, axis=0) returns the index of the minimum value in each of the four columns. The column minimums are 1, 2, 3 and 4, found in rows 0, 0, 2 and 2 (the rows are labeled in the comments below):

>>> a
array([[ 1,  2,  4,  7],  # 0
       [ 9, 88,  6, 45],  # 1
       [ 9, 76,  3,  4]]) # 2

>>> np.argmin(a, axis=0)
array([0, 0, 2, 2])
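One way to check this reading of axis=0 is to compute the same result by hand: take each column a[:, j] on its own and call np.argmin on it. This is a minimal sketch that assumes the example array a from the question; the loop is only for illustration and is slower than the vectorized call.

```python
import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

# axis=0: one result per column, each an index into that column's rows
per_column = np.array([np.argmin(a[:, j]) for j in range(a.shape[1])])

print(per_column)            # [0 0 2 2]
print(np.argmin(a, axis=0))  # [0 0 2 2]
```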

On the other hand, axis=1 means that the operation is performed across the rows of a.

That means np.argmin(a, axis=1) returns [0, 2, 2] because a has three rows. The index of the minimum value in the first row is 0, the index of the minimum value of the second and third rows is 2:

>>> a
#        0   1   2   3
array([[ 1,  2,  4,  7],
       [ 9, 88,  6, 45],
       [ 9, 76,  3,  4]])

>>> np.argmin(a, axis=1)
array([0, 2, 2])
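The same kind of hand check works for axis=1: iterating over a 2D array yields its rows, and calling np.argmin on each row gives one index per row. A minimal sketch, again assuming the question's array:

```python
import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

# axis=1: one result per row, each an index into that row's columns
per_row = np.array([np.argmin(row) for row in a])

print(per_row)               # [0 2 2]
print(np.argmin(a, axis=1))  # [0 2 2]
```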
Answered By: Alex Riley

The np.argmax function by default works along the flattened array, unless you specify an axis. To see what is happening you can use flatten explicitly:

>>> np.argmax(a)
5
>>> a.flatten()
array([ 1,  2,  4,  7,  9, 88,  6, 45,  9, 76,  3,  4])
#       0   1   2   3   4   5

I’ve numbered the indices under the array above to make it clearer. Note that indices are numbered from zero in numpy.
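A quick way to confirm that the call without an axis indexes the flattened array is to look the index up in a.ravel(), which returns the elements in the same row-major order as a.flatten(). A minimal sketch using the question's array:

```python
import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

i = np.argmax(a)     # index into the flattened (row-major) array
flat = a.ravel()     # same element order as a.flatten()

print(i)                   # 5
print(flat[i])             # 88
print(flat[i] == a.max())  # True
```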

When you specify the axis, it also works as expected:

>>> np.argmax(a, axis=0)
array([1, 1, 1, 1])

This tells you that, for each column along axis=0 (down), the largest value is in row 1 (the 2nd row). You can see this more clearly if you change your data a bit:

>>> a = np.array([[100,2,4,7],[9,88,6,45],[9,76,3,100]])
>>> a
array([[100,   2,   4,   7],
       [  9,  88,   6,  45],
       [  9,  76,   3, 100]])
>>> np.argmax(a, axis=0)
array([0, 1, 1, 2])

As you can see, it now finds the maximum in row 0 for column 0, in row 1 for columns 1 and 2, and in row 2 for column 3 (rows and columns are both counted from zero).
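One detail matters for the question's original data: column 0 of that array is [1, 9, 9], so its maximum appears twice. The NumPy documentation states that when the extreme value occurs more than once, argmax and argmin return the index of the first occurrence, which is why that column reports row 1 rather than row 2. A small check, assuming the question's array:

```python
import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

col0 = a[:, 0]
print(col0)              # [1 9 9]
print(np.argmax(col0))   # 1  (the first of the two 9s)

# The same first-occurrence rule applies to argmin
print(np.argmin(np.array([5, 2, 2, 7])))  # 1
```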

There is a useful guide to numpy indexing in the documentation.

Answered By: mfitzp

The axis argument of argmax refers to the axis along which the array is sliced.

In other words, np.argmin(a, axis=0) is effectively the same as np.apply_along_axis(np.argmin, 0, a): it finds the location of the minimum in each 1D slice of a taken along axis=0.
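The equivalence with np.apply_along_axis can be checked directly. The sketch below also uses a small 3D array b (arbitrary values made up for illustration, not from the question) to show the general rule: the result has the shape of the input with the chosen axis removed.

```python
import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

# 2D case: apply np.argmin to each 1D slice along axis 0 (each column)
print(np.apply_along_axis(np.argmin, 0, a))  # [0 0 2 2]
print(np.argmin(a, axis=0))                  # [0 0 2 2]

# 3D case: b is a hypothetical (2, 3, 4) array with arbitrary values;
# reducing over axis 1 leaves shape (2, 4)
b = (np.arange(24) * 7 % 11).reshape(2, 3, 4)
direct = np.argmin(b, axis=1)
sliced = np.apply_along_axis(np.argmin, 1, b)
print(direct.shape)                    # (2, 4)
print(np.array_equal(direct, sliced))  # True
```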

Therefore, in your example, np.argmin(a, axis=0) is [0, 0, 2, 2], which corresponds to the values [1, 2, 3, 4] in the respective columns.
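To go from the indices back to the values, np.take_along_axis (available since NumPy 1.15) can be used; the result should match a.min(axis=0). A minimal sketch, assuming the question's array:

```python
import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

idx = np.argmin(a, axis=0)  # [0 0 2 2]
# take_along_axis needs idx to have the same number of dimensions as a
vals = np.take_along_axis(a, idx[np.newaxis, :], axis=0)

print(vals)           # [[1 2 3 4]]
print(a.min(axis=0))  # [1 2 3 4]
```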

Answered By: xingzhi.sg

As a side note: if you want to find the coordinates of your maximum value in the full array, you can use

>>> a = np.array([[1,2,4,7],[9,88,6,45],[9,76,3,4]])
>>> print(a)
[[ 1  2  4  7]
 [ 9 88  6 45]
 [ 9 76  3  4]]
>>> i = int(np.argmax(a))  # flat index 5
>>> c = (i // len(a[0]), i % len(a[0]))  # use //, since / returns a float in Python 3
>>> c
(1, 1)
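NumPy also provides np.unravel_index, which converts a flat index into a tuple of coordinates for a given shape, so the division does not have to be done by hand; it works for any number of dimensions. A minimal sketch, assuming the question's array:

```python
import numpy as np

a = np.array([[1, 2, 4, 7], [9, 88, 6, 45], [9, 76, 3, 4]])

flat_index = np.argmax(a)  # 5
row, col = np.unravel_index(flat_index, a.shape)

print(int(row), int(col))  # 1 1
print(a[row, col])         # 88
```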
Answered By: MartijnVanAttekum
""" ....READ THE COMMENTS FOR CLARIFICATION....."""

import numpy as np
a = np.array([[1,2,4,7], [9,88,6,45], [9,76,3,4]])

"""np.argmax(a) will give index of max value in flatted array of given matrix """
>>np.argmax(a)
5

"""np.argmax(a,axis=0) will return list of indexes of  max value column-wise"""
>>print(np.argmax(a,axis=0))
[1,1,1,1]

"""np.argmax(a,axis=1) will return list of indexes of  max value row-wise"""
>>print(np.argmax(a,axis=1))
[3,1,1]

"""np.argmin(a) will give index of min value in flatted array of given matrix """
>>np.argmin(a)
0

"""np.argmin(a,axis=0) will return list of indexes of  min value column-wise"""
>>print(np.argmin(a,axis=0))
[0,0,2,2]

"""np.argmin(a,axis=0) will return list of indexes of  min value row-wise"""
>>print(np.argmin(a,axis=1))
[0,2,2]
Answered By: Nitin Ashutosh