PyTorch: find the matching 2×2 tensor in a batch of 2×2 tensors
Question:
I have a 2x2 reference tensor and a batch of candidate 2x2 tensors. I would like to find the candidate tensor closest to the reference tensor, measured by the sum of squared differences (i.e. the squared Euclidean distance) over identically indexed elements, ignoring the batch index.
For example:
ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)
I would like to find the index of the 2x2 tensor in candidates that minimizes:
(ref[0][0] - candidates[index][0][0])**2 +
(ref[0][1] - candidates[index][0][1])**2 +
(ref[1][0] - candidates[index][1][0])**2 +
(ref[1][1] - candidates[index][1][1])**2
Ideally, this solution would work for a reference tensor of arbitrary shape (b, c, d, ..., z) and an arbitrary batch_size of candidate tensors with the same dimensions as the reference tensor, i.e. of shape (batch_size, b, c, d, ..., z).
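For cross-checking vectorized solutions, a literal, loop-based transcription of this objective might look like the sketch below (not part of the original question). The function name closest_index_loop is hypothetical; it assumes each candidates[index] has the same shape as ref, and because it sums over every element it already works for any shape (b, c, d, ..., z).

```python
import torch


def closest_index_loop(ref: torch.Tensor, candidates: torch.Tensor) -> int:
    """Literal (slow) version of the objective; hypothetical helper for cross-checking."""
    best_index, best_value = -1, float("inf")
    for index in range(candidates.shape[0]):
        # Sum of (ref[...] - candidates[index][...])**2 over all identically indexed elements.
        value = ((ref.float() - candidates[index]) ** 2).sum().item()
        # Strict "<" keeps the first minimum, the same tie rule torch.argmin documents.
        if value < best_value:
            best_index, best_value = index, value
    return best_index


ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)
print(closest_index_loop(ref, candidates))
```

This is O(batch_size) Python iterations, so it is only meant as a reference to compare against.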
Answers:
The following line returns the index of the tensor in candidates that minimizes the sum of element-wise squared differences to ref:
In [1]: import torch
In [2]: ref = torch.as_tensor([[1, 2], [3, 4]])
...: candidates = torch.rand(100, 2, 2)
In [3]: %timeit torch.argmin(((ref - candidates) ** 2).sum((1, 2)))
16.9 µs ± 350 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
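The .sum((1, 2)) above hard-codes a 2-D reference. A minimal sketch that extends it to the arbitrary shapes asked about reduces over tuple(range(1, candidates.dim())) instead. The helper name closest_index is hypothetical, and it assumes ref has at least one dimension (an empty dim=() tuple would make sum reduce over everything) and that candidates.shape[1:] == ref.shape.

```python
import torch


def closest_index(ref: torch.Tensor, candidates: torch.Tensor) -> int:
    """Index of the candidate with the smallest sum of squared differences to ref.

    Hypothetical helper. Assumes ref.dim() >= 1 and candidates.shape == (batch_size, *ref.shape).
    """
    if candidates.shape[1:] != ref.shape:
        raise ValueError("candidates must have shape (batch_size, *ref.shape)")
    # Broadcasting subtracts ref from every candidate along the leading batch dimension.
    sq_diff = (ref.to(candidates.dtype) - candidates) ** 2
    # Reduce over every non-batch dimension, however many there are.
    per_candidate = sq_diff.sum(dim=tuple(range(1, candidates.dim())))
    return int(torch.argmin(per_candidate))


torch.manual_seed(0)
ref = torch.rand(3, 4, 5)
candidates = torch.rand(50, 3, 4, 5)
candidates[17] = ref  # plant an exact match
print(closest_index(ref, candidates))  # 17
```

Converting ref to the candidates' dtype up front mirrors what type promotion does for ref - candidates in the 2x2 example, but makes the intent explicit.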
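Elaborating on @ndrwnaguib’s answer, it could rather be written with torch.cdist on the flattened tensors. cdist returns the Euclidean distance, so squaring it recovers the sum of squared differences; the argmin is the same either way: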
dist = torch.cdist(ref.float().flatten().unsqueeze(0), candidates.flatten(start_dim=1))
print(torch.square(dist))
torch.argmin(dist)
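Output (the printed tensor has 10 entries, so this run apparently used 10 candidates rather than 100):
tensor([[23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
22.7513, 16.8489]])
tensor(9)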
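A sketch checking that the cdist route and the direct sum of squares agree, assuming a PyTorch version whose torch.cdist accepts the compute_mode argument. Passing compute_mode="donot_use_mm_for_euclid_dist" avoids the matrix-multiplication shortcut cdist may take for larger inputs, which can be slightly less precise. For arbitrary shapes the same flattening (ref.flatten() and candidates.flatten(start_dim=1)) applies unchanged.

```python
import torch

torch.manual_seed(0)
ref = torch.as_tensor([[1, 2], [3, 4]])
candidates = torch.rand(100, 2, 2)

# cdist works on floating-point row vectors: flatten each tensor to one row.
x1 = ref.float().flatten().unsqueeze(0)  # shape (1, 4)
x2 = candidates.flatten(start_dim=1)     # shape (100, 4)
dist = torch.cdist(x1, x2, compute_mode="donot_use_mm_for_euclid_dist")  # shape (1, 100)

# Direct sum of squared differences for comparison.
sq_sums = ((ref - candidates) ** 2).sum(dim=(1, 2))

# Squaring is monotonic on non-negative values, so it cannot change the argmin.
print(torch.allclose(dist.square().squeeze(0), sq_sums))
print(int(torch.argmin(dist)) == int(torch.argmin(sq_sums)))
```

Because dist has shape (1, N), torch.argmin(dist) without a dim returns a flattened index, which here coincides with the candidate index.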
Other options worth noting:
torch.square(ref.float() - candidates).sum(dim=(1, 2))
tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
22.7513, 16.8489])
diff = ref.float() - candidates
torch.einsum("abc,abc->a", diff, diff)
tensor([23.3516, 21.8078, 25.5247, 26.3465, 21.3161, 17.7537, 24.1075, 22.4388,
22.7513, 16.8489])
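The "abc,abc->a" subscripts fix the candidates at three dimensions. A minimal sketch for arbitrary shapes flattens the difference first so a single two-index equation works; it assumes candidates has at least two dimensions, and the helper name squared_distances_einsum is hypothetical.

```python
import torch


def squared_distances_einsum(ref: torch.Tensor, candidates: torch.Tensor) -> torch.Tensor:
    """Sum of squared differences between ref and each candidate (hypothetical helper)."""
    # Broadcast-subtract, then collapse all non-batch dimensions into one axis.
    diff = (ref.to(candidates.dtype) - candidates).flatten(start_dim=1)  # (batch_size, M)
    # "ij,ij->i" multiplies elementwise and sums over j: a per-row dot product of diff with itself.
    return torch.einsum("ij,ij->i", diff, diff)


torch.manual_seed(0)
ref = torch.rand(2, 3, 4)
candidates = torch.rand(10, 2, 3, 4)
sq = squared_distances_einsum(ref, candidates)
print(sq.shape)  # torch.Size([10])
print(int(torch.argmin(sq)))
```

The result is one value per candidate (a 1-D tensor), so torch.argmin(sq) directly gives the index of the closest candidate.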