Pairwise/row-wise comparison of PyTorch tensor

Question:

I have a 2D tensor representing integer coordinates on a grid, and I would like to check my tensor for any occurrences of a specific coordinate (x, y).

A pseudo-code example:

import torch

positions = torch.arange(20).repeat(2).view(-1, 2)  # 20 rows of (x, y) pairs
xy_dst1 = torch.tensor((5, 7))
xy_dst2 = torch.tensor((4, 5))
positions == xy_dst1  # should find no matching row
positions == xy_dst2  # should find rows 2 and 12

My only solution so far is to convert the tensors to lists or tuples and then iterate over them, but with all the conversions back and forth and the Python-level iteration, that can't be a very good solution.
Does anyone know of a better solution that stays within the tensor framework?
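For comparison, the list-based workaround described above might look roughly like the sketch below. It is only an illustration of that approach, not the asker's actual code; the function name find_rows_loop is hypothetical.

```python
import torch

def find_rows_loop(positions, xy):
    # Hypothetical helper: convert both tensors to plain Python objects...
    target = tuple(xy.tolist())
    rows = positions.tolist()
    # ...then scan the rows one by one in Python
    return [i for i, row in enumerate(rows) if tuple(row) == target]

positions = torch.arange(20).repeat(2).view(-1, 2)
print(find_rows_loop(positions, torch.tensor((4, 5))))  # [2, 12]
print(find_rows_loop(positions, torch.tensor((5, 7))))  # []
```

It gives the right answer, but every call copies the whole tensor into Python lists and loops at Python speed, which is the overhead the question wants to avoid.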

Asked By: Tue


Answers:

Try comparing every row at once with broadcasting, then reducing over the coordinate dimension:

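import torch

def check(positions, xy):
    # Compare every row with xy (broadcast from shape (1, 2)), keep rows
    # where both coordinates match, and return their indices as an (N, 1) tensor
    return (positions == xy.view(1, 2)).all(dim=1).nonzero()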

print(check(positions, xy_dst1))
# Output: tensor([], size=(0, 1), dtype=torch.int64)

print(check(positions, xy_dst2))
# Output:
# tensor([[ 2],
#         [12]])
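To see why the reduction with all(dim=1) matters: the bare comparison broadcasts xy across every row and compares element by element, so a row sharing only one coordinate still produces a True. With xy_dst1 = (5, 7), rows 3 and 13 are (6, 7) and share the y-value. The sketch below reuses the question's tensors and shows this, plus two common variants: a yes/no check with .any() and a flat 1-D index tensor via nonzero(as_tuple=True). The names elementwise, row_match and idx are illustrative only. Note that a (2,)-shaped xy broadcasts against (20, 2) the same way xy.view(1, 2) does.

```python
import torch

positions = torch.arange(20).repeat(2).view(-1, 2)  # shape (20, 2)
xy_dst1 = torch.tensor((5, 7))
xy_dst2 = torch.tensor((4, 5))

# Element-wise comparison: shape (20, 2), True wherever ONE coordinate matches
elementwise = positions == xy_dst1
print(elementwise.shape)          # torch.Size([20, 2])
print(elementwise.any().item())   # True, because rows 3 and 13 are (6, 7)

# Row-wise match: a row counts only if both coordinates match, shape (20,)
row_match = (positions == xy_dst1).all(dim=1)
print(row_match.any().item())     # False, no row equals (5, 7)

# Flat 1-D tensor of matching row indices instead of an (N, 1) column
idx = (positions == xy_dst2).all(dim=1).nonzero(as_tuple=True)[0]
print(idx.tolist())               # [2, 12]
```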
Answered By: kmkurn