Tensor slicing: tensorflow vs pytorch

Question:

I was testing this simple slicing operation in TF and PyTorch, which should produce matching results in both:

import tensorflow as tf
import numpy as np
import torch
tf_x = tf.random.uniform((4, 64, 64, 3))

pt_x = torch.Tensor(tf_x.numpy())
pt_x = pt_x.permute(0, 3, 1, 2)
# slicing operation
print(np.any(pt_x[:, :, 1:].permute(0, 2, 3, 1).numpy() - tf_x[:, 1:].numpy())) 
# > False

pt_x = torch.Tensor(tf_x.numpy())
b, h, w, c = pt_x.shape
pt_x = pt_x.reshape((b, c, h, w))
print(np.any(pt_x.view(b, h, w, c).numpy() - tf_x.numpy())) # False
print(np.any(pt_x[:, :, 1:].reshape(4, 63, 64, 3).numpy() - tf_x[:, 1:].numpy())) 
# > True

The last line is the problem: PyTorch and TF should produce the same values, but they don't. Is this discrepancy caused by reshaping the tensor?

Asked By: Abhijay Ghildyal


Answers:

On one hand, pt_x does hold the same values as tf_x; use np.isclose to verify:

>>> np.isclose(pt_x.view(b, h, w, c).numpy(), tf_x.numpy()).all()
True

On the other hand, you are slicing the two tensors differently: pt_x[:, :, 1:] removes the first element along axis=2, while tf_x[:, 1:] removes the first element along axis=1. Therefore you end up comparing distinct elements whose values merely happen to overlap, such as tf_x[:, 1:][0,-1,-1,-1] and pt_x[0,-1,-1,-1].
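To make the axis mismatch concrete, here is a minimal numpy sketch (the array and its shape are illustrative, not taken from the question):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)  # a small stand-in tensor

a = x[:, 1:]     # drops the first index along axis 1 -> shape (2, 2, 4)
b = x[:, :, 1:]  # drops the first index along axis 2 -> shape (2, 3, 3)

print(a.shape, b.shape)  # (2, 2, 4) (2, 3, 3)
```

The two slices remove entries along different axes, so they keep different sets of elements even though they start from the same array.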

Also keep in mind that tensor layouts differ between TensorFlow and PyTorch: the former uses a channels-last layout, while the latter uses channels-first. The operation needed to convert between the two is a permutation (not a reshape).
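The distinction can be checked directly; a small sketch with an illustrative tensor (not the one from the question):

```python
import torch

# A channels-last tensor, as TensorFlow would produce: (batch, height, width, channels)
x = torch.arange(2 * 3 * 3 * 2).reshape(2, 3, 3, 2).float()

# permute reorders the axes, moving each value to its logical channels-first position
permuted = x.permute(0, 3, 1, 2)   # shape (2, 2, 3, 3)

# reshape only reinterprets the same flat memory under a new shape
reshaped = x.reshape(2, 2, 3, 3)   # same shape, but the values are scrambled

print(torch.equal(permuted, reshaped))                # False
print(torch.equal(permuted.permute(0, 2, 3, 1), x))   # True: permute is invertible
```

This is exactly why the second snippet in the question fails: reshape produces a tensor with the right shape but the wrong element order, so a later slice picks out different values.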

Answered By: Ivan