Reshaping the dimensions of a tensor in PyTorch

Question:

There is a tensor with the shape [b, nt*nh*nw, dim], and the values of nt, nh, and nw are known. How can I reshape this tensor to the form [b, dim, nt, nh, nw]? For example, how can I reshape a tensor of shape [2, 3*2*4, 512] to [2, 512, 3, 2, 4]?

Asked By: dtr43


Answers:

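It all depends on how your data is laid out, that is, in which order the nt, nh, and nw positions were flattened into the second axis.

However, assuming that axis is ordered as (nt, nh, nw) with nw varying fastest (which is what you get when a [b, nt, nh, nw, dim] tensor is flattened with reshape), you can do this by permuting and then reshaping your tensor.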
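A minimal sketch of what that ordering assumption means, assuming the [b, nt*nh*nw, dim] tensor was produced by flattening a [b, nt, nh, nw, dim] tensor with `reshape` (the 5-D source tensor `grid` is hypothetical and filled with random data). Position (t, h, w) then lands at flat index t*nh*nw + h*nw + w:

```python
import torch

# Sizes from the question: b=2, (nt, nh, nw)=(3, 2, 4), dim=512
b, nt, nh, nw, dim = 2, 3, 2, 4, 512

# Hypothetical source tensor laid out as [b, nt, nh, nw, dim]
grid = torch.randn(b, nt, nh, nw, dim)

# Flatten the three axes into one in row-major order (nw varies fastest)
x = grid.reshape(b, nt * nh * nw, dim)
print(x.shape)  # torch.Size([2, 24, 512])

# Check the flat-index mapping for one (t, h, w) position
t, h, w = 2, 1, 3
flat = t * nh * nw + h * nw + w
print(torch.equal(x[:, flat], grid[:, t, h, w]))  # True
```

If your data was flattened in a different order (for example nt varying fastest), the view below will silently scramble positions instead of raising an error.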

First swap the dimensions to place dim on the second axis, using torch.transpose or torch.permute. Then reshape the tensor to the desired shape with Tensor.view or torch.reshape:

>>> x.transpose(1,2).view(b, dim, nt, nh, nw)
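A self-contained run with the sizes from the question, as a sketch: the input `x` is random stand-in data built from a hypothetical [b, nt, nh, nw, dim] tensor `grid`, so the result can be checked against a single `permute` of that source. `view` works here even though `transpose` makes the tensor non-contiguous, because it only splits the flattened axis; `reshape` is the fallback that copies when a view is impossible.

```python
import torch

b, nt, nh, nw, dim = 2, 3, 2, 4, 512

# Stand-in data: a [b, nt, nh, nw, dim] tensor flattened to [b, nt*nh*nw, dim]
grid = torch.randn(b, nt, nh, nw, dim)
x = grid.reshape(b, nt * nh * nw, dim)  # shape [2, 24, 512]

# Move dim to axis 1, then split the flattened axis back into (nt, nh, nw)
out = x.transpose(1, 2).view(b, dim, nt, nh, nw)
print(out.shape)  # torch.Size([2, 512, 3, 2, 4])

# Same result as permuting the 5-D source: [b, nt, nh, nw, dim] -> [b, dim, nt, nh, nw]
print(torch.equal(out, grid.permute(0, 4, 1, 2, 3)))  # True

# permute + reshape gives the same values; reshape returns a view when possible
out2 = x.permute(0, 2, 1).reshape(b, dim, nt, nh, nw)
print(torch.equal(out, out2))  # True
```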
Answered By: Ivan