Understanding PyTorch Tensor Slicing

Question:

Let a and b be two PyTorch tensors with a.shape=[A,3] and b.shape=[B,3]. Further, b is of type long.

I know there are several ways of slicing a. For example,

c = a[N1:N2:jump,[0,2]] # N1<N2<A

would return c.shape = [2,2] for N1=1 and N2=4 and jump=2.
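
A minimal, self-contained check of this claim (the tensor values are random; only the shapes matter):

import torch

a = torch.rand(10, 3)          # A = 10
c = a[1:4:2, [0, 2]]           # N1=1, N2=4, jump=2 -> rows 1 and 3, columns 0 and 2
print(c.shape)                 # torch.Size([2, 2])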

But the following should have thrown an error,

c = a[b]

but instead c.shape = [B,3,3].

For example,

a = torch.rand(10,3)
b = torch.rand(20,3).long()
print(a[b].shape) #torch.Size([20, 3, 3])

Can someone explain how the slicing is working for a[b]?

Asked By: Mohit Lamba


Answers:

Since b is of type long, torch treats its values as index positions; if it weren't an integer type, the above wouldn't work.
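
As a quick sanity check (using freshly created tensors of the same shapes, so the values differ from the session below), a float index tensor is rejected while the long version is accepted:

import torch

a = torch.rand(10, 3)
try:
    a[torch.rand(20, 3)]                     # float values cannot be used as indices
except (IndexError, RuntimeError) as err:
    print("float index tensor rejected:", type(err).__name__)

print(a[torch.rand(20, 3).long()].shape)     # torch.Size([20, 3, 3])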

In [29]: a[b]
Out[29]: 
tensor([[[-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[ 0.3707, -0.6549,  1.3003],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.7021, -1.1604, -0.8919],
         [ 0.3707, -0.6549,  1.3003],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.9325,  1.2281,  1.0513],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.7021, -1.1604, -0.8919]],

        [[-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-1.9443, -1.5545,  0.3944],
         [-0.4933,  0.8588,  1.5655]],

        [[-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655],
         [-0.4933,  0.8588,  1.5655]],

        [[-1.9443, -1.5545,  0.3944],
         [-0.7021, -1.1604, -0.8919],
         [-0.4933,  0.8588,  1.5655]]])

In [30]: a
Out[30]: 
tensor([[-0.4933,  0.8588,  1.5655],
        [-1.9443, -1.5545,  0.3944],
        [ 0.3707, -0.6549,  1.3003],
        [ 0.6938, -1.1753, -0.0484],
        [-0.0178, -0.0227,  0.3007],
        [-1.7586, -0.6923,  3.0981],
        [ 1.0726,  0.3889,  1.6468],
        [ 1.7248, -2.6932, -1.2202],
        [-0.9325,  1.2281,  1.0513],
        [-0.7021, -1.1604, -0.8919]])

In [31]: b
Out[31]: 
tensor([[ 0, -1,  0],
        [ 1, -1,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [ 1,  0,  0],
        [ 2,  0,  0],
        [ 0,  0,  0],
        [ 0,  0,  0],
        [-1,  2,  0],
        [-1,  0,  0],
        [ 0,  0,  0],
        [ 0, -1,  0],
        [ 1, -2,  0],
        [ 0,  0,  0],
        [ 0,  0, -1],
        [ 0,  0, -1],
        [-1,  0,  0],
        [ 0,  1,  0],
        [ 0,  0,  0],
        [ 1, -1,  0]])

Notice that the first element of a[b] consists of the first row of a, then the last row, and then the first row again, which corresponds to the indices [0, -1, 0]. Since, for each entry of b, the corresponding rows of a are sampled, you get the [20, 3, 3] shape.

So, given that each entry in b corresponds to a valid index into a, torch gathers the rows of a at the given positions. It does so for each entry of b and stacks everything into a new tensor with the above shape. If there is an invalid index (e.g. b = torch.randn(20, 3).long() * 10), you will get:

----> 1 a[b]

IndexError: index 10 is out of bounds for dimension 0 with size 10
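
In other words, for a 2-D long index tensor like b, a[b] behaves like gathering a[row] for every row of b and stacking the results; a short sketch of that equivalence, assuming all indices are valid:

import torch

a = torch.randn(10, 3)
b = torch.randn(20, 3).long()                    # small integers, valid row indices for a 10-row tensor

manual = torch.stack([a[row] for row in b])      # gather the rows of a for each row of b

print(a[b].shape)                                # torch.Size([20, 3, 3])
print(torch.equal(a[b], manual))                 # True
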
Answered By: David

Basics

  • When you use a[b], PyTorch is performing advanced indexing.
  • In this case, each row of the tensor b is treated as an index into the first dimension of a, and the corresponding rows of a are returned.
  • Since b has shape [B,3], each row of b is a set of 3 indices into the first dimension of a. So the result of a[b] will have shape [B,3,d], where d is the number of columns in a, as the quick shape check below illustrates.
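
A quick shape check with hypothetical sizes (A=7, d=5, B=4) makes the [B,3,d] rule concrete:

import torch

a = torch.randn(7, 5)                  # A = 7 rows, d = 5 columns
b = torch.randint(0, 7, (4, 3))        # B = 4 rows of valid long indices
print(a[b].shape)                      # torch.Size([4, 3, 5]) == [B, 3, d]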

For example

Suppose that b has the following values:

b = torch.tensor([[0,1,2], [3,4,5], [1,2,3]])

  • Then the result of a[b] will be a tensor with shape [3,3,3], where the first dimension corresponds to the three rows of b and the second dimension corresponds to the three indices in each row of b. The third dimension corresponds to the three columns of a.

Here’s how the values are computed:

The first row of b is [0,1,2].

  • This means that the first row of a is returned,
  • followed by the second row of a,
  • and then the third row of a.
  • So the first "slice" of the result will be:
[[a[0,0], a[0,1], a[0,2]],
 [a[1,0], a[1,1], a[1,2]],
 [a[2,0], a[2,1], a[2,2]]]

The second row of b is [3,4,5].

  • This means that the fourth row of a is returned,
  • followed by the fifth row of a,
  • and then the sixth row of a.
  • So the second "slice" of the result will be:
[[a[3,0], a[3,1], a[3,2]],
 [a[4,0], a[4,1], a[4,2]],
 [a[5,0], a[5,1], a[5,2]]]

The third row of b is [1,2,3].

  • This means that the second row of a is returned,
  • followed by the third row of a,
  • and then the fourth row of a.
  • So the third "slice" of the result will be:
[[a[1,0], a[1,1], a[1,2]],
 [a[2,0], a[2,1], a[2,2]],
 [a[3,0], a[3,1], a[3,2]]]

All of these slices are stacked along a new first dimension to produce the final result with shape [3,3,3].
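
A runnable version of this worked example (using an arbitrary 6x3 tensor for a so the rows are easy to tell apart) confirms each slice:

import torch

a = torch.arange(18.).reshape(6, 3)    # rows 0..5 with easily recognizable values
b = torch.tensor([[0, 1, 2], [3, 4, 5], [1, 2, 3]])

c = a[b]
print(c.shape)                         # torch.Size([3, 3, 3])
print(torch.equal(c[0], a[[0, 1, 2]])) # True: rows 0, 1, 2 of a
print(torch.equal(c[1], a[[3, 4, 5]])) # True: rows 3, 4, 5 of a
print(torch.equal(c[2], a[[1, 2, 3]])) # True: rows 1, 2, 3 of a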

Answered By: sogu