Get final values from a specific dimension/axis of an arbitrarily dimensioned PyTorch Tensor

Question:

Suppose I had a PyTorch tensor such as:

import torch

x = torch.randn([3, 4, 5])

and I wanted to get a new tensor, with the same number of dimensions, containing only the final entry along dimension 1 (so that dimension is kept, with size 1). I could just do:

x[:, -1:, :]
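To see why the slice -1: is used here rather than the integer index -1, the quick shape comparison below may help. It is only an illustration: a slice keeps the indexed dimension (with size 1), while an integer index removes it.

```python
import torch

x = torch.randn([3, 4, 5])

# A slice ("-1:") keeps dimension 1, now with size 1
kept = x[:, -1:, :]
# An integer index ("-1") removes dimension 1 entirely
dropped = x[:, -1, :]

print(kept.shape)     # torch.Size([3, 1, 5])
print(dropped.shape)  # torch.Size([3, 5])
```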

However, if x had an arbitrary number of dimensions, and I wanted to get the final values from a specific dimension, what is the best way to do it?
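One way to generalise the fixed slice above, sketched here for comparison and not taken from either answer, is to build the index tuple at runtime: ":" for every dimension except the chosen one, which gets "-1:". The helper name `last_along` is hypothetical.

```python
import torch

def last_along(t: torch.Tensor, dim: int) -> torch.Tensor:
    # Hypothetical helper: like t[:, ..., -1:, ..., :] with "-1:" at position `dim`
    index = [slice(None)] * t.dim()  # ":" for every dimension
    index[dim] = slice(-1, None)     # "-1:" for the chosen dimension
    return t[tuple(index)]           # basic slicing, so the result is a view of t

x = torch.randn([2, 3, 4, 5])
print(last_along(x, 2).shape)  # torch.Size([2, 3, 1, 5])
```

Because a Python list accepts negative indices, a negative `dim` such as -1 also lands on the intended entry of `index`.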

Asked By: Matt Pitkin


Answers:

You can use the torch.select function (or the equivalent Tensor.select method), e.g.,

dim = 1  # the dimension from which to extract the final values

y = x.select(dim, -1).unsqueeze(dim)

where unsqueeze re-inserts the size-1 dimension that select removes, so the result has the same number of dimensions as the original tensor.
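A small sketch of the select/unsqueeze approach wrapped in a function, checked against the fixed slice from the question and applied to a 5-D tensor. The function name `last_along_select` is hypothetical.

```python
import torch

def last_along_select(t: torch.Tensor, dim: int) -> torch.Tensor:
    # select() drops `dim`; unsqueeze() puts a size-1 axis back at the same position
    return t.select(dim, -1).unsqueeze(dim)

x = torch.randn([3, 4, 5])
y = last_along_select(x, 1)
print(y.shape)  # torch.Size([3, 1, 5])

# Same call for any number of dimensions
z = torch.randn([2, 3, 4, 5, 6])
print(last_along_select(z, 3).shape)  # torch.Size([2, 3, 4, 1, 6])
```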

Answered By: Matt Pitkin

You can use index_select:

torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1]))

The output tensor has the same number of dimensions as the input, with size 1 along dim. If you want to drop that dimension instead, call squeeze on dim:

torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).squeeze(dim=dim)
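The index_select approach can be sketched as a helper that builds a 1-D index tensor on the input's device and optionally squeezes the result. The function name `last_along_index_select` and its `keepdim` flag are hypothetical, chosen only for this illustration.

```python
import torch

def last_along_index_select(t: torch.Tensor, dim: int, keepdim: bool = True) -> torch.Tensor:
    # index_select takes a 1-D integer index tensor; create it on t's device
    index = torch.tensor([t.size(dim) - 1], device=t.device)
    out = torch.index_select(t, dim, index)  # size 1 along `dim`, newly allocated
    return out if keepdim else out.squeeze(dim)

x = torch.randn([3, 4, 5])
print(last_along_index_select(x, 1).shape)                 # torch.Size([3, 1, 5])
print(last_along_index_select(x, 1, keepdim=False).shape)  # torch.Size([3, 5])
```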

Note: While select returns a view of the input tensor, index_select returns a new tensor.
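The view-versus-copy note can be checked with a short experiment. This sketch uses a deterministic tensor so that the copied values are known to be non-zero: after zeroing the last entries of `x` in place, the select-based view reflects the change and the index_select result does not.

```python
import torch

x = torch.arange(60.0).reshape(3, 4, 5)
dim = 1
index = torch.tensor([x.size(dim) - 1])

view = x.select(dim, -1).unsqueeze(dim)  # shares storage with x
copy = torch.index_select(x, dim, index)  # separate memory

x[:, -1, :] = 0.0  # modify the original tensor in place

print(bool((view == 0).all()))  # True: the view sees the change
print(bool((copy == 0).any()))  # False: the copy keeps the old values
```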

Example:

In [1]: dim = 1

In [2]: x = torch.randn([3, 4, 5])

In [3]: torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).shape
Out[3]: torch.Size([3, 1, 5])

In [4]: torch.index_select(x, dim=dim, index=torch.tensor([x.size(dim) - 1])).squeeze(dim=dim).shape
Out[4]: torch.Size([3, 5])

Answered By: heemayl