Index a matrix but return a list of lists in PyTorch
Question:
I have a 2-dimensional tensor and I would like to index it so that the result is a list of lists. For example:
R = torch.tensor([[1,2,3], [4,5,6]])
mask = torch.tensor([[1,0,0],[1,1,1]], dtype=torch.bool)
output = R[mask]
This makes output equal to tensor([1, 4, 5, 6]). However, I would like to have [[1], [4,5,6]] or [tensor(1), tensor([4,5,6])].
I know that it could be done with a loop using .append(). However, I would like to avoid any loop so that it stays fast when R and mask are very big.
Is there any way to do that without a Python loop?
Answers:
You can combine boolean indexing with Tensor.split(), using the number of True values in each row as the split sizes:
output = R[mask].split(mask.sum(dim=1).tolist())
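This works because R[mask] collects the selected elements in row-major order, so the elements that came from the same row are contiguous in the flat result. mask.sum(dim=1) counts the True values in every row in a single vectorized call, and splitting the flat tensor by those counts gives a tuple of tensors in which each element corresponds to one row of the original tensor.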
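As a minimal, self-contained sketch of the steps above (using the example data from the question), note that the chunk sizes are converted to a plain Python list with .tolist() before being passed to split():

```python
import torch

# Example data from the question
R = torch.tensor([[1, 2, 3], [4, 5, 6]])
mask = torch.tensor([[1, 0, 0], [1, 1, 1]], dtype=torch.bool)

# Number of selected elements in each row, computed in one vectorized call
counts = mask.sum(dim=1)  # tensor([1, 3])

# R[mask] is a flat 1-D tensor whose elements appear in row-major order,
# so consecutive chunks of length counts[i] belong to row i
flat = R[mask]  # tensor([1, 4, 5, 6])

# split() takes the chunk sizes as a list of Python ints
output = flat.split(counts.tolist())

print(output)  # (tensor([1]), tensor([4, 5, 6]))
```

The pieces returned by split() are views into flat, so no element data is copied at this step.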
Tensor.split() returns a tuple of tensors. If you need plain Python lists instead, convert each element:
output = [t.tolist() for t in output]
This gives a list of lists in which each sublist corresponds to a row of the original tensor with the masked-out (False) positions removed. Note that this conversion step is itself a Python-level loop over the rows, so skip it if a tuple of tensors is enough.
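A sketch wrapping both steps into one function is shown below; the name masked_rows is hypothetical and not part of PyTorch. It also illustrates an edge case worth checking: a row with no True values produces a zero-length chunk, so it appears as an empty list instead of being dropped, which keeps the output aligned with the rows of R.

```python
import torch


def masked_rows(R: torch.Tensor, mask: torch.Tensor) -> list:
    """Hypothetical helper: return the masked elements of each row of a
    2-D tensor as a list of Python lists."""
    # One vectorized selection plus one split; the only Python iteration
    # is the final per-row .tolist() conversion
    chunks = R[mask].split(mask.sum(dim=1).tolist())
    return [t.tolist() for t in chunks]


R = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
mask = torch.tensor([[1, 0, 0], [0, 0, 0], [1, 1, 1]], dtype=torch.bool)

# The all-False middle row yields an empty list rather than disappearing
print(masked_rows(R, mask))  # [[1], [], [7, 8, 9]]
```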