PyTorch: find the closest 2×2 tensor in a batch of 2×2 tensors

Question:

I have a 2x2 reference tensor and a batch of candidate 2x2 tensors. I would like to find the candidate tensor closest to the reference tensor, measured by the sum of squared differences between identically indexed elements (ignoring the batch index), i.e. the squared Euclidean distance.

For example:

import torch

ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)

I would like to find the index in candidates of the 2x2 tensor that minimizes:

(ref[0][0] - candidates[index][0][0])**2 + 
(ref[0][1] - candidates[index][0][1])**2 + 
(ref[1][0] - candidates[index][1][0])**2 + 
(ref[1][1] - candidates[index][1][1])**2
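For reference, the same objective can be evaluated with a plain Python loop over the batch. This is a minimal, deliberately slow sketch meant only for checking vectorized answers; the helper name `closest_index_loop` is hypothetical. Like `torch.argmin`, it keeps the first index on ties.

```python
import torch

def closest_index_loop(ref: torch.Tensor, candidates: torch.Tensor) -> int:
    """Hypothetical brute-force helper: score each candidate one at a time
    by its sum of squared differences to ref and keep the smallest."""
    best_index, best_score = -1, float("inf")
    for index in range(candidates.shape[0]):
        # Squared differences summed over every non-batch position
        # (for 2x2 tensors this is exactly the four-term expression above)
        score = ((ref - candidates[index]) ** 2).sum().item()
        if score < best_score:  # strict "<" keeps the first minimum on ties
            best_index, best_score = index, score
    return best_index

ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)
print(closest_index_loop(ref, candidates))
```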

Ideally, the solution would work for a reference tensor of arbitrary shape (b, c, d, ..., z) and an arbitrary number of candidate tensors of the same shape, i.e. a candidates tensor of shape (batch_size, b, c, d, ..., z).
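Written for a reference tensor $r$ of any shape and candidates $c_i$ of the same shape, the objective is the squared Euclidean distance between the flattened tensors, where $j$ runs over every element position. The second form holds because the square root is monotonically increasing, so minimizing the distance or its square selects the same index:

```latex
\hat{\imath} = \operatorname*{arg\,min}_{i} \sum_{j} \left( r_j - c_{i,j} \right)^2
             = \operatorname*{arg\,min}_{i} \lVert r - c_i \rVert_2
```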

Asked By: BeginnersMindTruly


Answers:

The following line returns the index of the tensor in candidates that minimizes the sum of element-wise squared differences to ref:

In [1]: import torch
In [2]: ref = torch.as_tensor([[1, 2], [3, 4]])
   ...: candidates = torch.rand(100, 2, 2)
In [3]: %timeit torch.argmin(((ref - candidates) ** 2).sum((1, 2)))
16.9 µs ± 350 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
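The `.sum((1, 2))` call above hard-codes two non-batch dimensions. A minimal sketch for the arbitrary-shape case in the question, assuming `candidates` has shape `(batch_size, *ref.shape)`, flattens every non-batch dimension before summing; the helper name `closest_index` is hypothetical.

```python
import torch

def closest_index(ref: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: index of the candidate with the smallest sum of
    squared differences to ref. Assumes candidates.shape == (N, *ref.shape)."""
    if candidates.shape[1:] != ref.shape:
        raise ValueError("candidates must have shape (batch_size, *ref.shape)")
    # ref broadcasts against the batch dimension
    squared = (candidates - ref) ** 2
    # flatten(start_dim=1) merges all non-batch dims, so one sum covers any rank
    return squared.flatten(start_dim=1).sum(dim=1).argmin()

ref = torch.rand(3, 4, 5)
candidates = torch.rand(100, 3, 4, 5)
print(closest_index(ref, candidates))
```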
Answered By: ndrwnaguib

Elaborating on @ndrwnaguib’s answer, it should rather be (the outputs below are from a run with 10 candidates):

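# Flatten ref to a (1, M) row and each candidate to a row of an (N, M) matrix
dist = torch.cdist(ref.float().flatten().unsqueeze(0), candidates.flatten(start_dim=1))
print(torch.square(dist))  # squared distances, shape (1, N)
torch.argmin(dist)  # dist has a single row, so the flat index is the candidate index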

tensor([[23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
         22.7513, 16.8489]])

tensor(9)
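`torch.cdist` returns the Euclidean distance itself (the square root of the sum of squared differences), which is why the snippet squares it before printing; the square root does not change which index is smallest. A minimal sketch of the same idea for reference tensors of any shape, assuming `candidates` is floating point with shape `(batch_size, *ref.shape)`; the helper name `closest_index_cdist` is hypothetical.

```python
import torch

def closest_index_cdist(ref: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: flatten each tensor to a row vector and let
    torch.cdist compute Euclidean distances between the rows."""
    # cdist needs floating-point inputs; cast ref to the (assumed float) candidates dtype
    x1 = ref.to(candidates.dtype).flatten().unsqueeze(0)  # shape (1, M)
    x2 = candidates.flatten(start_dim=1)                  # shape (N, M)
    dist = torch.cdist(x1, x2)                            # shape (1, N)
    return dist.squeeze(0).argmin()

ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)
print(closest_index_cdist(ref, candidates))
```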

Other options worth noting:

torch.square(ref.float() - candidates).sum(dim=(1, 2))

tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
        22.7513, 16.8489])

diff = ref.float() - candidates
torch.einsum("abc,abc->a", diff, diff)

tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
        22.7513, 16.8489])
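All three expressions compute the same per-candidate sum of squared differences (the `cdist` one after squaring), so they should agree up to floating-point rounding. A small sketch that checks this on random data with a fixed seed; `torch.allclose` with its default tolerances is used because the formulations sum in different orders.

```python
import torch

torch.manual_seed(0)  # fixed seed so the check is reproducible
ref = torch.as_tensor([[1, 2], [3, 4]]).float()
candidates = torch.rand(10, 2, 2)

diff = ref - candidates
by_square = torch.square(diff).sum(dim=(1, 2))      # shape (10,)
by_einsum = torch.einsum("abc,abc->a", diff, diff)  # shape (10,)
by_cdist = torch.square(
    torch.cdist(ref.flatten().unsqueeze(0), candidates.flatten(start_dim=1))
).squeeze(0)                                        # (1, 10) -> (10,)

print(torch.allclose(by_square, by_einsum), torch.allclose(by_square, by_cdist))
print(by_square.argmin(), by_einsum.argmin(), by_cdist.argmin())
```

The `sum` and `einsum` forms return a 1-D tensor with one entry per candidate, while `cdist` on 2-D inputs returns a `(1, N)` matrix, hence the `squeeze(0)`.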
Answered By: Alexey Birukov