Keeping gradients while rearranging data in a tensor, with PyTorch

Question:

I have a scheme where I store a matrix with zeros on the diagonal as a vector. I want to later optimize over that vector, so I require gradient tracking.
My challenge is to reshape between the two representations.

I want – for domain-specific reasons – to keep the order of the data such that transposed elements of the W matrix end up next to each other in the vector form.

The size of the W matrix is subject to change, so I start by enumerating items in the top-left part of the matrix and continue outwards.

I have come up with two ways to do this. See code snippet.

import torch
import torch.sparse

w = torch.tensor([10, 11, 12, 13, 14, 15], requires_grad=True, dtype=torch.float)

# (row, col, position-in-w) triples for the sparse "reshaper" tensor
i = torch.LongTensor([
    [0, 1, 0],
    [1, 0, 1],
    [0, 2, 2],
    [2, 0, 3],
    [1, 2, 4],
    [2, 1, 5],
])
v = torch.FloatTensor([1, 1, 1, 1, 1, 1])
reshaper = torch.sparse.FloatTensor(i.t(), v, torch.Size([3, 3, 6])).to_dense()
W_mat_with_reshaper = reshaper @ w

# build the same matrix directly from the elements of w
W_mat_directly = torch.tensor([
    [0,    w[0], w[2]],
    [w[1], 0,    w[4]],
    [w[3], w[5], 0],
])

print(W_mat_with_reshaper)
print(W_mat_directly)

and this gives output

tensor([[ 0., 10., 12.],
        [11.,  0., 14.],
        [13., 15.,  0.]], grad_fn=<UnsafeViewBackward>)
tensor([[ 0., 10., 12.],
        [11.,  0., 14.],
        [13., 15.,  0.]])

As you can see, the direct way to reshape the vector into a matrix does not have a grad function, but the multiply-with-a-reshaper-tensor approach does. Creating the reshaper tensor seems like it will be a hassle, but on the other hand, manually writing out the matrix form is also infeasible.
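(As an aside, the reshaper does not have to be written by hand. Below is a minimal sketch that generates it for an arbitrary size n, assuming the enumeration pattern implied by the 3×3 example above: for each column j > 0, the pair (i, j), (j, i) for every row i < j. The helper name build_reshaper is made up for illustration.)

import torch

def build_reshaper(n):
    # Dense (n, n, n*(n-1)) tensor R such that R @ w lays the vector w out
    # with transposed pairs in consecutive slots, as in the 3x3 example.
    idx, k = [], 0
    for j in range(1, n):
        for i in range(j):
            idx.append([i, j, k])
            idx.append([j, i, k + 1])
            k += 2
    indices = torch.tensor(idx).t()
    values = torch.ones(indices.shape[1])
    return torch.sparse_coo_tensor(indices, values, (n, n, n * (n - 1))).to_dense()

w = torch.tensor([10., 11, 12, 13, 14, 15], requires_grad=True)
print(build_reshaper(3) @ w)  # same matrix (and grad_fn) as W_mat_with_reshaper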

Is there a way to do arbitrary reshapes in PyTorch that keeps track of gradients?

Asked By: LudvigH


Answers:

Instead of constructing W_mat_directly from the elements of w, try assigning w into W:

W_mat_directly = torch.zeros((3, 3), dtype=w.dtype)
W_mat_directly[(0, 0, 1, 1, 2, 2), (1, 2, 0, 2, 0, 1)] = w

You’ll get

tensor([[ 0., 10., 11.],
        [12.,  0., 13.],
        [14., 15.,  0.]], grad_fn=<IndexPutBackward>)
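
To get the exact layout asked for in the question (transposed pairs adjacent in w), the same assignment works with the row and column indices reordered. A minimal sketch of that variant, with a quick check that gradients flow back to w (the index order is taken from the question's 3×3 example):

import torch

w = torch.tensor([10, 11, 12, 13, 14, 15], requires_grad=True, dtype=torch.float)

# w = [W[0,1], W[1,0], W[0,2], W[2,0], W[1,2], W[2,1]]
rows = (0, 1, 0, 2, 1, 2)
cols = (1, 0, 2, 0, 2, 1)

W = torch.zeros((3, 3), dtype=w.dtype)
W[rows, cols] = w       # in-place assignment recorded by autograd (IndexPut)

W.sum().backward()
print(W)                # [[0, 10, 12], [11, 0, 14], [13, 15, 0]]
print(w.grad)           # all ones, so w is tracked through the assignment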
Answered By: Shai

You can use the facts that:

  • slicing (e.g. w[0:1]) returns a one-element tensor that stays connected to the autograd graph;
  • concatenating tensors preserves gradients, while creating a new tensor with torch.tensor(...) copies the values and does not.

tensor0 = torch.zeros(1)
W_mat_directly = torch.concatenate(
    [tensor0, w[0:1], w[2:3], w[1:2], tensor0, w[4:5], w[3:4], w[5:6], tensor0]
).reshape(3, 3)

With this approach you can apply arbitrary functions to the elements of the initial tensor w.
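
For instance, here is a small sketch (reusing the snippet above) that squares the off-diagonal entries before placing them in the matrix and checks that the transformation stays in the autograd graph:

import torch

w = torch.tensor([10, 11, 12, 13, 14, 15], requires_grad=True, dtype=torch.float)
zero = torch.zeros(1)

# square each entry; any differentiable function of the slices would work
W = torch.concatenate(
    [zero, w[0:1] ** 2, w[2:3] ** 2,
     w[1:2] ** 2, zero, w[4:5] ** 2,
     w[3:4] ** 2, w[5:6] ** 2, zero]
).reshape(3, 3)

W.sum().backward()
print(w.grad)  # equals 2 * w, so the squaring is part of the graph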

Answered By: LGrementieri