Use torch.gather to select images from tensor

Question:

I have a tensor of images of size (3600, 32, 32, 3) and a multi-hot tensor [0, 1, 1, 0, …] of size (3600, 1). I want to select the images that correspond to a 1 in the multi-hot tensor. I am trying to understand how to use torch.gather:

tensorA.gather(0, tensorB)

This gives me dimension errors, and I can't work out how to reshape the tensors.

Asked By: MichaelMMeskhi


Answers:

When using torch.gather, input and index must have the same number of dimensions. Also, index is not a multi-hot mask; it holds the positions of the values you want to pick.
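To make the gather rules concrete, here is a minimal sketch of how torch.gather itself could do the selection, assuming a small 4-image batch as a stand-in for the real data (the names images, mask, idx and index are illustrative). The mask is first converted to positions, and those positions are reshaped and expanded so the index tensor has the same 4 dimensions as the images:

```python
import torch

torch.manual_seed(0)
images = torch.randn(4, 32, 32, 3)   # toy batch: 4 images of 32x32x3
mask = torch.tensor([0, 1, 1, 0])    # multi-hot selection mask

# gather wants positions, not a mask: convert the mask to indices -> tensor([1, 2])
idx = torch.nonzero(mask, as_tuple=True)[0]

# gather also wants index to have as many dims as input (4 here),
# so reshape to (n, 1, 1, 1) and expand to (n, 32, 32, 3)
index = idx.view(-1, 1, 1, 1).expand(-1, *images.shape[1:])

# out[i, h, w, c] = images[index[i, h, w, c], h, w, c] = images[idx[i], h, w, c]
selected = torch.gather(images, 0, index)
print(selected.shape)  # torch.Size([2, 32, 32, 3])
```

This works, but the expand step only exists to satisfy gather's shape rule; for whole-image selection, plain indexing with idx gives the same result more directly.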

A simpler approach is to slice the tensor using indices derived from the multi-hot tensor: torch.where(tensorB == 1)[0] finds the positions whose value is 1, and tensorA[tensorB_where] selects the images at those positions.

Code:

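import torch

tensorA = torch.randn(4, 32, 32, 3)   # 4 images of 32x32x3
tensorB = torch.tensor([0, 1, 1, 0])  # multi-hot selection mask

# Positions where the mask is 1 -> tensor([1, 2])
tensorB_where = torch.where(tensorB == 1)[0]
# Index along dim 0 to keep only those images -> shape (2, 32, 32, 3)
result = tensorA[tensorB_where]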
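The mask in the question has shape (3600, 1) rather than the 1-D mask used above. (torch.where(tensorB == 1)[0] still works on it, because element [0] of the result holds the row indices.) Below is a minimal sketch, assuming a random stand-in mask in place of the real data, of two equivalent ways to select once the trailing dimension is squeezed away: boolean-mask indexing and index_select.

```python
import torch

torch.manual_seed(0)
images = torch.randn(3600, 32, 32, 3)   # shape from the question
mask = torch.randint(0, 2, (3600, 1))   # stand-in multi-hot mask, shape (3600, 1)

flat = mask.squeeze(1)                  # (3600, 1) -> (3600,)

# Option 1: boolean-mask indexing along dim 0
by_bool = images[flat.bool()]

# Option 2: explicit indices with index_select along dim 0
by_index = images.index_select(0, torch.nonzero(flat, as_tuple=True)[0])

print(by_bool.shape[0] == int(mask.sum()))  # True
```

Either result has shape (number of ones in the mask, 32, 32, 3).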
Answered By: core_not_dumped