PyTorch slice matrix with vector
Question:
Say I have one matrix and one vector as follows:
import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])
Is there a way to index x with y, something like x[y],
so that the result is:
res = [1, 6, 8]
So basically, for each row of x I want the element whose column is given by the matching entry of y: the first element of y picks the column in the first row, the second element picks the column in the second row, and so on.
Answers:
You can pass the row indices explicitly alongside y:
import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])
x[range(x.shape[0]), y]
# tensor([1, 6, 8])
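To make the pairing of row and column indices visible, here is a minimal sketch of the same lookup with the row indices built as a tensor. It assumes x is 2-D and y holds one valid column index per row; the names rows and picked are only illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])

# One row index per row of x: tensor([0, 1, 2])
rows = torch.arange(x.shape[0])

# The two index tensors are paired element-wise,
# so this reads x[0, 0], x[1, 2], x[2, 1]
picked = x[rows, y]
print(picked)  # tensor([1, 6, 8])
```

Using torch.arange instead of range keeps the row indices as a tensor, which may be convenient if x lives on a GPU (in that case pass device=x.device).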
Advanced indexing in PyTorch works the same way as NumPy's, i.e. the indexing arrays are broadcast together across the axes, so you could do as in FBruzzesi's answer.
Alternatively, similar to np.take_along_axis in NumPy, PyTorch has torch.gather, which takes values along a specific axis:
x.gather(1, y.view(-1,1)).view(-1)
# tensor([1, 6, 8])
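The one-liner above can be unpacked step by step. This is only a sketch of what each part does, assuming y is an int64 tensor with one column index per row of x; the names idx, out and res are illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])

# gather expects an index tensor with the same number of dimensions as x,
# so reshape y from shape (3,) to shape (3, 1)
idx = y.unsqueeze(1)

# With dim=1: out[i][j] = x[i][idx[i][j]], so out has shape (3, 1)
out = x.gather(1, idx)

# Drop the trailing dimension to get a flat vector
res = out.squeeze(1)
print(res)  # tensor([1, 6, 8])
```

Here unsqueeze(1) and squeeze(1) play the same role as the two view calls in the answer.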