Why does dim=1 return row indices in torch.argmax?

Question:

I am working with PyTorch's argmax function, which is defined as:

torch.argmax(input, dim=None, keepdim=False)

Consider this example:

a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))

Here, when I use dim=1, the function appears to search along each row vector rather than each column vector, as shown below.

print(a) :   
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],  
        [ 0.6378,  0.6575, -1.2970, -0.0625],  
        [ 1.7970, -1.3463,  0.9011, -0.8704],  
        [ 1.5639,  0.7123,  0.0385,  1.8410]])  

print(torch.argmax(a, dim=1))  
tensor([1, 1, 0, 3])

My assumption was that dim=0 represents rows and dim=1 represents columns.

Asked By: Programmer

||

Answers:

It’s time to correctly understand how the axis or dim argument works in PyTorch:

[Image: tensor dimension]


The following example should make sense once you comprehend the above picture:

    |
    v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
    |   [[-1.7739,  0.8073,  0.0472, -0.4084],
    v    [ 0.6378,  0.6575, -1.2970, -0.0625],
    |    [ 1.7970, -1.3463,  0.9011, -0.8704],
    v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
    |
    v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])

Note: dim (short for ‘dimension’) is the torch equivalent of ‘axis’ in NumPy.
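
To compare the two directions side by side, here is a minimal sketch that hard-codes the tensor printed in the question (so the values do not depend on torch.randn) and takes argmax along each dimension. The names per_row and per_col are only illustrative.

```python
import torch

# The tensor printed in the question, hard-coded so the result is reproducible.
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# dim=1: the column index varies while the row is fixed,
# so there is one winning column index per row.
per_row = torch.argmax(a, dim=1)
print(per_row)  # tensor([1, 1, 0, 3])

# dim=0: the row index varies while the column is fixed,
# so there is one winning row index per column.
per_col = torch.argmax(a, dim=0)
print(per_col)  # tensor([2, 0, 2, 3])
```

In other words, dim names the direction that is scanned and collapsed, not the direction that survives in the output.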

Answered By: kmario23

Dimensions are defined as shown in the excellent answer above. Here is how I understand dimensions in PyTorch and NumPy (dim and axis respectively); I hope it is helpful to others.

Notice that only the specified dimension’s index varies during the argmax operation, and that dimension’s index range collapses to a single index once the operation is complete. Let tensor A have M rows and N columns, so that its shape is (M, N), and consider the sum operation for simplicity. If dim=0 is specified, then the row vectors A[0,:], A[1,:], …, A[M-1,:] are summed elementwise and the result is another tensor with 1 row and N columns. Only the 0th dimension’s index varies, from 0 through M-1. Similarly, if dim=1 is specified, then the column vectors A[:,0], A[:,1], …, A[:,N-1] are summed elementwise and the result is another tensor with M rows and 1 column.

An example is given below:

>>> A = torch.tensor([[1,2,3], [4,5,6]])
>>> A
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> S0 = torch.sum(A, dim = 0)
>>> S0
tensor([5, 7, 9])
>>> S1 = torch.sum(A, dim = 1)
>>> S1
tensor([ 6, 15])

In the sample code above, the first sum operation specifies dim=0, so A[0,:] and A[1,:], which are [1, 2, 3] and [4, 5, 6], are summed elementwise to give [5, 7, 9]. When dim=1 is specified, the vectors A[:,0], A[:,1], and A[:,2], which are [1, 4], [2, 5], and [3, 6], are added elementwise to give [6, 15].
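
The same sums can be written out with explicit loops to make it visible which index varies. This is only an illustrative sketch (the names manual_s0 and manual_s1 are made up here), not how torch.sum is implemented.

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
M, N = A.shape  # M = 2 rows, N = 3 columns

# dim=0: for each fixed column j, only the row index i varies.
manual_s0 = torch.tensor([sum(A[i, j].item() for i in range(M))
                          for j in range(N)])

# dim=1: for each fixed row i, only the column index j varies.
manual_s1 = torch.tensor([sum(A[i, j].item() for j in range(N))
                          for i in range(M)])

print(manual_s0)  # tensor([5, 7, 9])
print(manual_s1)  # tensor([ 6, 15])

# Both agree with the built-in reduction.
print(torch.equal(manual_s0, torch.sum(A, dim=0)))  # True
print(torch.equal(manual_s1, torch.sum(A, dim=1)))  # True
```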

Note also that the specified dimension collapses. Again let A have the shape (M, N). If dim=0, then the result conceptually has the shape (1, N), where dimension 0 is reduced from M to 1. Similarly, if dim=1, then the result has the shape (M, 1), where dimension 1 is reduced from N to 1. With the default keepdim=False the collapsed dimension is dropped, so these results are returned as one-dimensional tensors with N and M elements respectively; passing keepdim=True keeps the shapes (1, N) and (M, 1).
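
The effect on shapes can be checked directly. The sketch below uses the keepdim parameter from the signature in the question and assumes the same 2 x 3 tensor A as in the example above; the variable names are illustrative.

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])  # shape (M, N) = (2, 3)

# Default keepdim=False: the reduced dimension is dropped entirely.
s0 = torch.sum(A, dim=0)
s1 = torch.sum(A, dim=1)
print(s0.shape)  # torch.Size([3])
print(s1.shape)  # torch.Size([2])

# keepdim=True: the reduced dimension is kept with size 1.
s0_keep = torch.sum(A, dim=0, keepdim=True)
s1_keep = torch.sum(A, dim=1, keepdim=True)
print(s0_keep.shape)  # torch.Size([1, 3])
print(s1_keep.shape)  # torch.Size([2, 1])

# argmax follows the same shape rule.
am_keep = torch.argmax(A, dim=1, keepdim=True)
print(am_keep.shape)  # torch.Size([2, 1])
```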

Answered By: Ismet Sahin