Row-wise sorting a batch of PyTorch tensors by column value

Question:

I would like to reorder the rows of each m x n matrix in a b x m x n PyTorch tensor (where b is the batch size) according to the value in the k-th column of each row. My input tensor is b x m x n, and my output tensor should also be b x m x n, with the rows of each m x n matrix rearranged in ascending order of their k-th column value. Each row must stay intact; only the order of the rows changes.

For example, if my original tensor is:

a = torch.as_tensor([[[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]], [[1, 0, 1, 0], [2, 1, 0, 3], [0, 0, 6, 1]]])

My sorted tensor should be:

sorted_dim = 1 # sort by rows, preserving each row
sorted_column = 2 # sort rows by their value in the 3rd column (index 2)
sorted_a = torch.as_tensor([[[3, 0, 5, 8], [9, 0, 6, 2], [1, 3, 7, 6]], [[2, 1, 0, 3], [1, 0, 1, 0], [0, 0, 6, 1]]])
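For a single m x n matrix, the operation described is ordinary row indexing: argsort the k-th column, then use the resulting indices to pick rows. A minimal sketch on the first matrix of the example (the names m, k, order and sorted_m are illustrative, not from the question):

```python
import torch

# One m x n slice from the question's example (the first batch element).
m = torch.as_tensor([[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]])
k = 2  # column whose values determine the row order

# argsort of column k gives the new row order; indexing with it
# moves whole rows, so every row is preserved intact.
order = m[:, k].argsort()
sorted_m = m[order]
print(order)  # tensor([2, 1, 0])
print(sorted_m)
```

The batched case is the same idea applied to every matrix along dim 0.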

Thanks!

Asked By: BeginnersMindTruly


Answers:

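Use torch.argsort on the sort column (here column 2) to get the new row order of each matrix in the batch, then index each matrix with its own order and stack the results: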

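import torch

a = torch.as_tensor([[[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]], [[1, 0, 1, 0], [2, 1, 0, 3], [0, 0, 6, 1]]])

# Row order for each batch element, from the values in column 2: shape (b, m)
order = torch.argsort(a[:, :, 2])
# Reorder the rows of each m x n matrix and stack back into (b, m, n)
sorted_a = torch.stack([a[i, order[i], :] for i in range(a.shape[0])])
print(sorted_a)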

output:

tensor([[[3, 0, 5, 8],
         [9, 0, 6, 2],
         [1, 3, 7, 6]],

        [[2, 1, 0, 3],
         [1, 0, 1, 0],
         [0, 0, 6, 1]]])
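The Python loop over the batch can also be avoided with torch.gather. A minimal sketch, assuming a 3-D (b, m, n) input and a PyTorch version whose torch.sort accepts the stable=True keyword (not available in very old releases); sort_rows_by_column is a hypothetical helper name:

```python
import torch

def sort_rows_by_column(t: torch.Tensor, k: int, descending: bool = False) -> torch.Tensor:
    """Reorder the rows (dim 1) of every matrix in a (b, m, n) batch by column k."""
    # Row order for each batch element, shape (b, m).
    # stable=True keeps rows with equal keys in their original relative order.
    _, order = torch.sort(t[:, :, k], dim=1, descending=descending, stable=True)
    # Repeat each row index across all n columns so gather moves whole rows.
    index = order.unsqueeze(-1).expand(-1, -1, t.size(2))
    return torch.gather(t, 1, index)

a = torch.as_tensor([[[1, 3, 7, 6], [9, 0, 6, 2], [3, 0, 5, 8]],
                     [[1, 0, 1, 0], [2, 1, 0, 3], [0, 0, 6, 1]]])
print(sort_rows_by_column(a, 2))
```

This gives the same sorted_a as the loop-based version. Plain torch.argsort (as used above) does not guarantee stability by default, which only matters when the sort column contains ties.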
Answered By: Nelson aka SpOOKY