Why does dim=1 return row indices in torch.argmax?
Question:
I am working with the argmax function of PyTorch, which is defined as:
torch.argmax(input, dim=None, keepdim=False)
Consider an example:
import torch
a = torch.randn(4, 4)
print(a)
print(torch.argmax(a, dim=1))
Here, when I use dim=1, the function seems to search along row vectors instead of column vectors, as shown below.
print(a) :
tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
        [ 0.6378,  0.6575, -1.2970, -0.0625],
        [ 1.7970, -1.3463,  0.9011, -0.8704],
        [ 1.5639,  0.7123,  0.0385,  1.8410]])
print(torch.argmax(a, dim=1))
tensor([1, 1, 0, 3])
My assumption was that dim=0 represents rows and dim=1 represents columns.
Answers:
It’s time to correctly understand how the axis or dim argument works in PyTorch:
The following example should make sense once you comprehend the above picture:
            |
            v
  dim-0  ---> -----> dim-1 ------> -----> --------> dim-1
            |   [[-1.7739,  0.8073,  0.0472, -0.4084],
            v    [ 0.6378,  0.6575, -1.2970, -0.0625],
            |    [ 1.7970, -1.3463,  0.9011, -0.8704],
            v    [ 1.5639,  0.7123,  0.0385,  1.8410]]
            |
            v
# argmax (indices where max values are present) along dimension-1
In [215]: torch.argmax(a, dim=1)
Out[215]: tensor([1, 1, 0, 3])
Note: dim (short for ‘dimension’) is the torch equivalent of ‘axis’ in NumPy.
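To see the difference between the two dimensions side by side, here is a minimal sketch that rebuilds the tensor from the question by hand (the values are copied from the printed output, so they are rounded) and takes the argmax along each dimension. The variable names per_row, per_col, and manual_per_row are only illustrative.

```python
import torch

# Values copied from the printed tensor in the question (rounded to 4 decimals).
a = torch.tensor([[-1.7739,  0.8073,  0.0472, -0.4084],
                  [ 0.6378,  0.6575, -1.2970, -0.0625],
                  [ 1.7970, -1.3463,  0.9011, -0.8704],
                  [ 1.5639,  0.7123,  0.0385,  1.8410]])

# dim=1: the column index varies, so each row produces one result,
# namely the column position of that row's maximum.
per_row = torch.argmax(a, dim=1)
print(per_row)  # tensor([1, 1, 0, 3])

# dim=0: the row index varies, so each column produces one result,
# namely the row position of that column's maximum.
per_col = torch.argmax(a, dim=0)
print(per_col)  # tensor([2, 0, 2, 3])

# The same dim=1 result, spelled out with plain indexing over the rows.
manual_per_row = torch.stack([torch.argmax(a[i, :]) for i in range(a.shape[0])])
print(torch.equal(per_row, manual_per_row))  # True
```

In other words, dim names the index that gets scanned and removed, not the index that survives in the output.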
Dimensions are defined as shown in the excellent answer above. Here is how I understand dimensions in PyTorch and NumPy (dim and axis respectively); I hope it is helpful to others.
Notice that only the specified dimension’s index varies during the argmax operation, and the specified dimension’s index range reduces to a single index once the operation is completed. Let tensor A have M rows and N columns, so its shape is (M, N), and consider the sum operation for simplicity. If dim=0 is specified, the vectors A[0,:], A[1,:], …, A[M-1,:] are summed elementwise and the result has 1 row and N columns. Notice that only the 0th dimension’s index varies, from 0 through M-1. Similarly, if dim=1 is specified, the vectors A[:,0], A[:,1], …, A[:,N-1] are summed elementwise and the result has M rows and 1 column.
An example is given below:
>>> import torch
>>> A = torch.tensor([[1,2,3], [4,5,6]])
>>> A
tensor([[1, 2, 3],
        [4, 5, 6]])
>>> S0 = torch.sum(A, dim = 0)
>>> S0
tensor([5, 7, 9])
>>> S1 = torch.sum(A, dim = 1)
>>> S1
tensor([ 6, 15])
In the above sample code, the first sum operation specifies dim=0, so A[0,:] and A[1,:], which are [1, 2, 3] and [4, 5, 6], are summed elementwise, resulting in [5, 7, 9]. When dim=1 is specified, the vectors A[:,0], A[:,1], and A[:,2], which are [1, 4], [2, 5], and [3, 6], are added elementwise to give [6, 15].
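The walkthrough above can be checked by writing the slice additions out explicitly. This is only a sketch of what the reduction computes, not of how PyTorch implements it; rows_added and cols_added are illustrative names.

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# dim=0: the row index varies, so the rows A[0, :] and A[1, :] are added elementwise.
rows_added = A[0, :] + A[1, :]
print(rows_added)  # tensor([5, 7, 9])

# dim=1: the column index varies, so the columns A[:, 0], A[:, 1], A[:, 2] are added elementwise.
cols_added = A[:, 0] + A[:, 1] + A[:, 2]
print(cols_added)  # tensor([ 6, 15])

# Both match the built-in reductions along the corresponding dimension.
print(torch.equal(rows_added, torch.sum(A, dim=0)))  # True
print(torch.equal(cols_added, torch.sum(A, dim=1)))  # True
```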
Note also that the specified dimension collapses. Again let A have the shape (M, N). If dim=0, dimension 0 is reduced from M to 1, conceptually giving the shape (1, N). Similarly, if dim=1, N is reduced to 1, conceptually giving the shape (M, 1). With the default keepdim=False, the collapsed dimension is dropped, so the results are returned as one-dimensional tensors with N and M elements respectively; passing keepdim=True retains the shapes (1, N) and (M, 1).
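The shapes can be inspected directly. The following is a minimal sketch reusing the 2x3 tensor A from the example above; keepdim is the parameter from the torch.argmax signature quoted in the question, and torch.sum accepts it as well. The variable names are illustrative.

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])  # shape (M, N) = (2, 3)

# Default keepdim=False: the reduced dimension is dropped entirely.
dropped0 = torch.sum(A, dim=0)
dropped1 = torch.sum(A, dim=1)
print(tuple(dropped0.shape), tuple(dropped1.shape))  # (3,) (2,)

# keepdim=True: the reduced dimension is kept with size 1.
kept0 = torch.sum(A, dim=0, keepdim=True)
kept1 = torch.sum(A, dim=1, keepdim=True)
print(tuple(kept0.shape), tuple(kept1.shape))  # (1, 3) (2, 1)

# argmax follows the same shape rule.
idx = torch.argmax(A, dim=1, keepdim=True)
print(idx.tolist(), tuple(idx.shape))  # [[2], [2]] (2, 1)
```

Keeping the size-1 dimension is mainly useful when the result has to broadcast back against the original tensor.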