How to sort PyTorch tensors by a specific key value?

Question:

I’m new to PyTorch. Given a 2-D tensor, I need to sort its rows by a key value stored in the last column.
For example,

A = 
[[0.9133, 0.5071, 0.6222, 3.],
 [0.5951, 0.9315, 0.6548, 1.],
 [0.7704, 0.0720, 0.0330, 2.]]

My expected result after sorting is:

A' = 
[[0.5951, 0.9315, 0.6548, 1.],
 [0.7704, 0.0720, 0.0330, 2.],
 [0.9133, 0.5071, 0.6222, 3.]]

I tried Python's built-in sorted function, but it was too slow in my training process.
How can I do this more efficiently?
Thanks!

Asked By: Humberto

Answers:

%%timeit -r 10 -n 10
A[A[:,-1].argsort()]

38.6 µs ± 23 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

%%timeit -r 10 -n 10
sorted(A, key = lambda x: x[-1])

69.6 µs ± 34.8 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

Both give the same rows in the same order (sorted returns a list of row tensors rather than a single tensor):

tensor([[0.5951, 0.9315, 0.6548, 1.0000],
        [0.7704, 0.0720, 0.0330, 2.0000],
        [0.9133, 0.5071, 0.6222, 3.0000]])
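For readers outside IPython, the comparison above can be sketched with the standard timeit module. This is a minimal sketch, not the original benchmark: the helper names by_argsort and by_sorted are made up for illustration, and torch.stack is used to turn the list returned by sorted back into one tensor. Timings vary by machine and tensor size, so none are quoted here.

```python
import timeit

import torch

A = torch.tensor([[0.9133, 0.5071, 0.6222, 3.0],
                  [0.5951, 0.9315, 0.6548, 1.0],
                  [0.7704, 0.0720, 0.0330, 2.0]])


def by_argsort(t):
    # Hypothetical helper: index the rows with the permutation
    # that sorts the last column.
    return t[t[:, -1].argsort()]


def by_sorted(t):
    # Hypothetical helper: sorted() iterates over the rows and returns a
    # list of 1-D tensors, so stack them to get a 2-D tensor back.
    return torch.stack(sorted(t, key=lambda row: row[-1]))


# Check that both approaches agree before timing them.
same = torch.equal(by_argsort(A), by_sorted(A))

# Total time for 100 calls each; results depend on hardware.
t_argsort = timeit.timeit(lambda: by_argsort(A), number=100)
t_sorted = timeit.timeit(lambda: by_sorted(A), number=100)

print(f"same result: {same}")
print(f"argsort: {t_argsort:.6f} s, sorted: {t_sorted:.6f} s for 100 calls")
```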

Then there is

%%timeit -r 10 -n 10
a, b = torch.sort(A, dim=-2)

The slowest run took 8.45 times longer than the fastest. This could mean that an intermediate result is being cached.
14.3 µs ± 18.1 µs per loop (mean ± std. dev. of 10 runs, 10 loops each)

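with a as the sorted tensor and b as the indices. Note that torch.sort along dim=-2 sorts each column independently, so the rows stay intact only when every column happens to share the same ordering; that is not the case for the example A.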
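To see how a column-wise sort differs from sorting whole rows by a key, here is a small sketch using the example A from the question. The variable names values, indices, order and rows_sorted are chosen only for illustration.

```python
import torch

A = torch.tensor([[0.9133, 0.5071, 0.6222, 3.0],
                  [0.5951, 0.9315, 0.6548, 1.0],
                  [0.7704, 0.0720, 0.0330, 2.0]])

# Sorting along dim=-2 sorts every column on its own,
# so values from different rows get mixed together.
values, indices = torch.sort(A, dim=-2)

# Row-preserving alternative: sort only the key column,
# then use the resulting permutation to index whole rows.
order = torch.argsort(A[:, -1])
rows_sorted = A[order]

print(values)       # columns sorted independently
print(rows_sorted)  # original rows, reordered by the last column
```

The first row of values combines entries taken from three different rows of A, while every row of rows_sorted is an unchanged row of A.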

Answered By: warped

I was looking for a way to use `torch.sort` without breaking the row order, but had no luck. There does not appear to be any direct implementation. It is a bit late, but for people like me, this is what I came up with:

Assuming we initially have a tensor:

import numpy as np
import torch

A = np.array([[0.9133, 0.5071, 0.6222, 3.],
              [0.5951, 0.9315, 0.6548, 1.],
              [0.7704, 0.0720, 0.0330, 2.]])

tensor_A = torch.tensor(A)

%%timeit -r 100 -n 10000
A = tensor_A.cpu().detach().numpy()
torch.tensor(A[A[:,-1].argsort()])

6.71 µs ± 163 ns per loop (mean ± std. dev. of 100 runs, 10000 loops each)

%%timeit -r 100 -n 10000
torch.tensor(sorted(tensor_A.cpu().detach().numpy(), key = lambda x: x[-1]))

6.85 µs ± 212 ns per loop (mean ± std. dev. of 100 runs, 10000 loops each)

%%timeit -r 100 -n 10000
tensor_A[torch.sort(tensor_A[:,-1])[1]]

6.71 µs ± 81 ns per loop (mean ± std. dev. of 100 runs, 10000 loops each)

I compared them the way warped did. Initially the first option was the fastest, but when I increased the number of loops, the last option performed the same. So I would only consider the last solution for long loops, not for occasional use.
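The last option can be wrapped in a small helper that stays in PyTorch and avoids the NumPy round trip. This is only a sketch: sort_rows_by_column is a made-up name, and it assumes an installed PyTorch whose torch.sort accepts the stable argument (newer releases do; on versions without it, drop that argument).

```python
import torch


def sort_rows_by_column(t, col=-1, descending=False):
    """Hypothetical helper: return the rows of a 2-D tensor ordered by one column."""
    # stable=True keeps rows with equal keys in their original relative order.
    # Assumption: this PyTorch version supports `stable` in torch.sort.
    _, order = torch.sort(t[:, col], descending=descending, stable=True)
    # Indexing with the permutation moves whole rows, so rows stay intact.
    return t[order]


A = torch.tensor([[0.9133, 0.5071, 0.6222, 3.0],
                  [0.5951, 0.9315, 0.6548, 1.0],
                  [0.7704, 0.0720, 0.0330, 2.0]])

# Example with repeated keys to show the stable ordering.
B = torch.tensor([[0.1, 2.0],
                  [0.2, 1.0],
                  [0.3, 2.0],
                  [0.4, 1.0]])

print(sort_rows_by_column(A))
print(sort_rows_by_column(B))
```

Because the helper only indexes the input tensor, the result stays on the same device as the input.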

Python version: 3.6.8.
PyTorch version: 1.5.0

Answered By: Andrés Tello Urrea