Index a matrix but return a list of lists in PyTorch

Question:

I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:

import torch

R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]

This sets output to the flat tensor tensor([1, 4, 5, 6]). However, I would like to get [[1], [4,5,6]] or [tensor(1), tensor([4,5,6])].

I know that it could be done with a loop and .append(). However, I would like to avoid any loop so that it stays fast when R and mask are very big.
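For comparison, here is a minimal sketch of the loop-based approach the question wants to avoid, using the same R and mask. It iterates over rows in Python and applies each row's 1-D boolean mask separately:

```python
import torch

R = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask = torch.tensor([[1, 0, 0], [1, 1, 1]], dtype=torch.bool)

# Iterating over a 2-D tensor yields its rows; index each row with its own mask.
output = []
for row, row_mask in zip(R, mask):
    output.append(row[row_mask])

print([t.tolist() for t in output])  # [[1], [4, 5, 6]]
```

This runs one Python-level iteration per row, which is what becomes slow when R has many rows.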

Is there any way to do that in Python without any loop?

Asked By: Josemi

Answers:

You can combine boolean-mask indexing with .split(), computing the per-row sizes with one vectorized sum instead of a Python list comprehension over the rows:

output = R[mask].split(mask.sum(dim=1).tolist())

This works because R[mask] returns the selected values in row-major order, so splitting that flat result into chunks whose sizes are the number of True values in each row of the mask gives one tensor per row of the original tensor.
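A step-by-step sketch with the tensors from the question makes the intermediate values visible: the flat masked result, the per-row counts, and the resulting chunks.

```python
import torch

R = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask = torch.tensor([[1, 0, 0], [1, 1, 1]], dtype=torch.bool)

flat = R[mask]                        # masked values in row-major order
counts = mask.sum(dim=1)              # number of True entries in each row
output = flat.split(counts.tolist())  # one chunk per row of R

print(flat)    # tensor([1, 4, 5, 6])
print(counts)  # tensor([1, 3])
print(output)  # (tensor([1]), tensor([4, 5, 6]))
```

Note that each chunk is a 1-D tensor, so the single-element row comes out as tensor([1]) rather than the 0-dim tensor(1).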
The .split() method returns a tuple of tensors, so you can convert each element of the tuple to a Python list as follows:

output = [t.tolist() for t in output]

This gives a list of lists in which each sublist holds the values of one row of the original tensor, with the entries where the mask is False removed.
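The same steps can be wrapped in a small reusable function. The name masked_rows is hypothetical, not part of PyTorch. This sketch assumes R and mask are 2-D tensors of the same shape and that mask has dtype torch.bool; rows with no True values come out as empty lists, because .split() accepts zero-sized chunks.

```python
import torch


def masked_rows(R, mask):
    """Hypothetical helper: for each row of R, return the values where mask is True."""
    if R.dim() != 2 or R.shape != mask.shape:
        raise ValueError("R and mask must be 2-D tensors with the same shape")
    counts = mask.sum(dim=1).tolist()  # per-row chunk sizes as Python ints
    chunks = R[mask].split(counts)     # tuple of 1-D tensors, one per row
    return [chunk.tolist() for chunk in chunks]


R = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = torch.tensor([[1, 0, 0], [0, 0, 0], [1, 1, 1]], dtype=torch.bool)
print(masked_rows(R, mask))  # [[1], [], [7, 8, 9]]
```

If only the tuple of tensors is needed, dropping the final .tolist() conversion avoids building Python lists for every element.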

Answered By: tmc