PyTorch slice matrix with vector

Question:

Say I have one matrix and one vector as follows:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

Is there a way to index x with y (something like x[y]) so that the result is:

res = [1, 6, 8]

In other words, for each row i, I want the element in column y[i]: row 0, column 0 gives 1; row 1, column 2 gives 6; row 2, column 1 gives 8.
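A plain Python loop spells out the intended result. It is only a slow reference sketch, useful for checking a vectorized version against, not a recommended solution:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])

# Reference loop: for row i, take the element in column y[i].
# Each x[i, y[i]] is a 0-d tensor; torch.stack joins them into shape (3,).
res = torch.stack([x[i, y[i]] for i in range(x.shape[0])])
print(res)  # tensor([1, 6, 8])
```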

Asked By: Dr. Prof. Patrick


Answers:

You can pass the row indices explicitly alongside y:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
# tensor([1, 6, 8])
Answered By: FBruzzesi

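Advanced indexing in PyTorch works just like NumPy's: the index arrays are broadcast together, and each position of the broadcast result selects one element. That is why FBruzzesi's answer works: range(x.shape[0]) and y both have length 3, so they pair up as (0, 0), (1, 2), (2, 1).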
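A short sketch of that broadcasting rule, using torch.arange for the row indices instead of range (an equivalent choice, not taken from the answer). The second lookup reshapes the row indices into a column, so it broadcasts against y to shape (3, 3) and selects every (row, y[j]) combination; the diagonal of that grid is the result asked for:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])

# Row indices 0..n-1 as a tensor.
rows = torch.arange(x.shape[0])

# rows and y both have shape (3,), so they pair up element-wise:
# (0, 0), (1, 2), (2, 1) -> x[0, 0], x[1, 2], x[2, 1]
res = x[rows, y]
print(res)  # tensor([1, 6, 8])

# Shapes (3, 1) and (3,) broadcast to (3, 3): grid[i, j] = x[rows[i], y[j]].
grid = x[rows[:, None], y]
print(grid)
# tensor([[1, 3, 2],
#         [4, 6, 5],
#         [7, 9, 8]])
```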

Alternatively, just as NumPy has np.take_along_axis, PyTorch has torch.gather, which takes values along a specific axis:

x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])
Answered By: yatu
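To show what torch.gather does here, the sketch below spells out its indexing rule for dim=1 and uses unsqueeze/squeeze in place of the two view calls (an equivalent choice, not from the answer). The second call, with a hypothetical index tensor idx, shows that gather can also pick several columns per row:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])

# torch.gather with dim=1 computes out[i][j] = x[i][index[i][j]].
# index must have as many dimensions as x, hence the extra column
# axis added by unsqueeze(1): shape (3,) -> (3, 1).
picked = torch.gather(x, 1, y.unsqueeze(1))  # shape (3, 1)
res = picked.squeeze(1)                      # shape (3,)
print(res)  # tensor([1, 6, 8])

# Hypothetical index with two columns per row: picks two elements per row.
idx = torch.tensor([[0, 2],
                    [2, 1],
                    [1, 0]])
out = torch.gather(x, 1, idx)
print(out)
# tensor([[1, 3],
#         [6, 5],
#         [8, 7]])
```

The output of gather always has the shape of the index tensor, which is why the single-column case needs the squeeze (or view(-1)) at the end.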