Understanding PyTorch Tensor Slicing
Question:
Let a and b be two PyTorch tensors with a.shape=[A,3] and b.shape=[B,3]. Further, b is of type long.

I know there are several ways of slicing a. For example,

c = a[N1:N2:jump, [0,2]]  # N1 < N2 < A

would return c.shape = [2,2] for N1=1, N2=4, and jump=2.

But the line below should have thrown an error,

c = a[b]

yet instead c.shape = [B,3,3]. For example,

a = torch.rand(10,3)
b = torch.rand(20,3).long()
print(a[b].shape)  # torch.Size([20, 3, 3])

Can someone explain how the slicing works for a[b]?
Answers:
Since b is of type long, torch treats its values as index positions; if it were not an integer type, the indexing above would not work.
In [29]: a[b]
Out[29]:
tensor([[[-0.4933, 0.8588, 1.5655],
[-0.7021, -1.1604, -0.8919],
[-0.4933, 0.8588, 1.5655]],
[[-1.9443, -1.5545, 0.3944],
[-0.7021, -1.1604, -0.8919],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-1.9443, -1.5545, 0.3944],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[ 0.3707, -0.6549, 1.3003],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.7021, -1.1604, -0.8919],
[ 0.3707, -0.6549, 1.3003],
[-0.4933, 0.8588, 1.5655]],
[[-0.7021, -1.1604, -0.8919],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.7021, -1.1604, -0.8919],
[-0.4933, 0.8588, 1.5655]],
[[-1.9443, -1.5545, 0.3944],
[-0.9325, 1.2281, 1.0513],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.7021, -1.1604, -0.8919]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.7021, -1.1604, -0.8919]],
[[-0.7021, -1.1604, -0.8919],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-1.9443, -1.5545, 0.3944],
[-0.4933, 0.8588, 1.5655]],
[[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655],
[-0.4933, 0.8588, 1.5655]],
[[-1.9443, -1.5545, 0.3944],
[-0.7021, -1.1604, -0.8919],
[-0.4933, 0.8588, 1.5655]]])
In [30]: a
Out[30]:
tensor([[-0.4933, 0.8588, 1.5655],
[-1.9443, -1.5545, 0.3944],
[ 0.3707, -0.6549, 1.3003],
[ 0.6938, -1.1753, -0.0484],
[-0.0178, -0.0227, 0.3007],
[-1.7586, -0.6923, 3.0981],
[ 1.0726, 0.3889, 1.6468],
[ 1.7248, -2.6932, -1.2202],
[-0.9325, 1.2281, 1.0513],
[-0.7021, -1.1604, -0.8919]])
In [31]: b
Out[31]:
tensor([[ 0, -1, 0],
[ 1, -1, 0],
[ 0, 0, 0],
[ 0, 0, 0],
[ 1, 0, 0],
[ 2, 0, 0],
[ 0, 0, 0],
[ 0, 0, 0],
[-1, 2, 0],
[-1, 0, 0],
[ 0, 0, 0],
[ 0, -1, 0],
[ 1, -2, 0],
[ 0, 0, 0],
[ 0, 0, -1],
[ 0, 0, -1],
[-1, 0, 0],
[ 0, 1, 0],
[ 0, 0, 0],
[ 1, -1, 0]])
Notice that the first element of a[b] is the first row of a, then the last row, then the first row again, which corresponds to the indices [0, -1, 0] (negative indices count from the end). Since this lookup is performed for every entry of b, picking the corresponding rows of a, you get the [20, 3, 3] shape.

So, given that each entry in b corresponds to a valid index into a, torch indexes a at the given positions. It does so for each entry of b and stacks the results into a new tensor with the above shape. If there is an invalid index (e.g. b = torch.randn(20, 3).long() * 10), you will get:

----> 1 a[b]
IndexError: index 10 is out of bounds for dimension 0 with size 10
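A minimal sketch of these rules (the index tensor must be an integer type, negative indices count from the end, and out-of-range values raise an IndexError), assuming only torch is installed:

```python
import torch

a = torch.rand(10, 3)

# A long index tensor works; -1 picks the last row of a.
b = torch.tensor([[0, -1, 0]])
print(a[b].shape)                       # torch.Size([1, 3, 3])
print(torch.equal(a[b][0, 1], a[-1]))   # True

# A float index tensor is rejected outright.
try:
    a[torch.rand(2, 3)]
except IndexError as e:
    print("float index rejected:", e)

# An integer index outside [-10, 9] is out of bounds for dimension 0.
try:
    a[torch.tensor([10])]
except IndexError as e:
    print("out of bounds:", e)
```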
Basics
- When you use a[b], PyTorch is performing advanced indexing.
- In this case, each element of the tensor b is treated as an index into the first dimension of a, and the corresponding rows of a are returned.
- Since b has shape [B,3], this means that each row of b is a 3-element index into the first dimension of a. So the result of a[b] will have shape [B,3,d], where d is the number of columns in a.
For example, suppose that b has the following values:
b = torch.tensor([[0,1,2], [3,4,5], [1,2,3]])
- Then the result of a[b] will be a tensor with shape [3,3,3], where the first dimension corresponds to the three rows of b and the second dimension corresponds to the three indices in each row of b. The third dimension corresponds to the three columns of a.
Here’s how the values are computed:
- The first row of b is [0,1,2].
- This means that the first row of a is returned,
- followed by the second row of a, and then the third row of a.
- So the first "slice" of the result will be:
[[a[0,0], a[0,1], a[0,2]],
[a[1,0], a[1,1], a[1,2]],
[a[2,0], a[2,1], a[2,2]]]
- The second row of b is [3,4,5].
- This means that the fourth row of a is returned,
- followed by the fifth row of a,
- and then the sixth row of a.
- So the second "slice" of the result will be:
[[a[3,0], a[3,1], a[3,2]],
[a[4,0], a[4,1], a[4,2]],
[a[5,0], a[5,1], a[5,2]]]
- The third row of b is [1,2,3].
- This means that the second row of a is returned,
- followed by the third row of a,
- and then the fourth row of a.
- So the third "slice" of the result will be:
[[a[1,0], a[1,1], a[1,2]],
[a[2,0], a[2,1], a[2,2]],
[a[3,0], a[3,1], a[3,2]]]
All of these slices are stacked along a new first dimension to produce the final result with shape [3,3,3].
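The construction above can be written as an explicit loop; this is a slow but equivalent formulation, shown only to make the semantics concrete (a is given hypothetical values here so that every index in b is valid):

```python
import torch

a = torch.arange(18, dtype=torch.float32).reshape(6, 3)
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])

# Advanced indexing gathers one row of a per index in b...
fancy = a[b]

# ...which is the same as stacking the selected rows by hand.
manual = torch.stack([torch.stack([a[j] for j in row]) for row in b])

print(torch.equal(fancy, manual))  # True
print(fancy.shape)                 # torch.Size([3, 3, 3])
```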