Check if each element of a tensor is contained in a list
Question:
Say I have a tensor A and a container of values vals. Is there a clean way of returning a Boolean tensor of the same shape as A, with each element being whether that element of A is contained within vals? E.g.:
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
vals = [1, 5]
# Desired output
torch.tensor([[True, False, False],
              [False, True, False]])
Answers:
You can achieve this by comparing A against each value in a loop and summing the resulting masks:
sum(A == i for i in vals).bool()
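A self-contained sketch of the loop-and-sum approach, run on the example from the question and assuming vals is a Python list. The helper names isin_by_sum and isin_by_broadcast are hypothetical. The empty-list guard is an addition: sum() over an empty sequence returns the plain integer 0, which has no .bool() method. The second helper is a swapped-in alternative, not part of the original answer: it does all comparisons in one broadcasted step, at the cost of an intermediate tensor of shape A.shape + (len(vals),).

```python
import torch

def isin_by_sum(A, vals):
    """Hypothetical helper: True where an element of A equals any value in vals (a list)."""
    if not vals:
        # sum() over nothing returns the int 0, so handle the empty case explicitly
        return torch.zeros_like(A, dtype=torch.bool)
    # Each A == v is a bool tensor shaped like A; adding them counts the matches per element
    counts = sum(A == v for v in vals)
    # Any nonzero count means "contained in vals"
    return counts.bool()

def isin_by_broadcast(A, vals):
    """Hypothetical helper: same result via a single broadcasted comparison."""
    v = torch.tensor(vals, dtype=A.dtype, device=A.device)
    # A.unsqueeze(-1) has shape A.shape + (1,); comparing with v gives A.shape + (len(vals),)
    return (A.unsqueeze(-1) == v).any(dim=-1)

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
vals = [1, 5]

print(isin_by_sum(A, vals).tolist())        # [[True, False, False], [False, True, False]]
print(isin_by_broadcast(A, vals).tolist())  # [[True, False, False], [False, True, False]]
```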
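You can also use Tensor.apply_ with a lambda:
result = A.clone().apply_(lambda x: x in vals).bool()
Then result will contain this tensor:
tensor([[ True, False, False],
        [False,  True, False]])
Here the lambda is applied to every element by the apply_ method (see torch.Tensor.apply_ in the official documentation). apply_ works in place, which is why A is cloned first; otherwise A itself would be overwritten with 0s and 1s. It also only works on CPU tensors and is slow, because it calls the Python function once per element.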
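A runnable sketch of the apply_ approach on the question's example. It calls apply_ on a copy made with clone(), since apply_ modifies the tensor it is called on, and it assumes A is a CPU tensor (apply_ does not work on GPU tensors).

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
vals = [1, 5]

# apply_ overwrites the tensor it is called on, so work on a copy to keep A intact.
# The lambda's True/False result is stored in the copy's integer dtype as 1/0,
# and .bool() then converts those 1/0 values into a Boolean mask.
result = A.clone().apply_(lambda x: x in vals).bool()

print(result.tolist())  # [[True, False, False], [False, True, False]]
print(A.tolist())       # [[1, 2, 3], [4, 5, 6]]  (unchanged thanks to clone())
```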
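Or, with a plain nested list comprehension (this produces a Python list of lists, so wrap it in torch.tensor if you need a tensor back):
torch.tensor([list(map(lambda x: x in vals, row)) for row in A])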
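A short sketch of the list-comprehension approach, showing that the comprehension itself yields a plain Python list and that torch.tensor converts it back into a Boolean tensor. Like apply_, it runs the membership test in Python once per element, so it is only practical for small tensors.

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
vals = [1, 5]

# Iterating over A yields its rows; iterating over a row yields 0-d tensors.
# `x in vals` compares x == v for each v and truth-tests the resulting 0-d bool tensor.
nested = [list(map(lambda x: x in vals, row)) for row in A]
print(type(nested))  # <class 'list'>

# torch.tensor turns the list of lists of Python bools into a bool tensor
result = torch.tensor(nested)
print(result.tolist())  # [[True, False, False], [False, True, False]]
```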
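Use torch.isin, which is the most convenient way (available since PyTorch 1.10). It expects its second argument to be a tensor or a scalar rather than a Python list, so convert vals first:
torch.isin(A, torch.tensor(vals))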
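A minimal sketch of the torch.isin approach on the question's example. Creating the test tensor with A's dtype and device is a choice made here so that the call also works when A lives on a GPU; the invert=True call shows the complementary mask (elements not contained in vals).

```python
import torch

A = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
vals = [1, 5]

# torch.isin takes tensors (or scalars), so convert the Python list first,
# matching A's dtype and device so both operands live in the same place.
test_elements = torch.tensor(vals, dtype=A.dtype, device=A.device)

mask = torch.isin(A, test_elements)
print(mask.tolist())  # [[True, False, False], [False, True, False]]

# invert=True gives the complement: elements of A NOT contained in vals
inverse = torch.isin(A, test_elements, invert=True)
print(inverse.tolist())  # [[False, True, True], [True, False, True]]
```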