PyTorch: a similar process to reverse pooling and replicate padding?
Question:
I have a tensor A that has shape (batch_size, width, height). Assume that it has these values:
A = torch.tensor([[[0, 1],
                   [1, 0]]])
I am also given a number K that is a positive integer. Let K=2 in this case. I want to do a process that is similar to reverse pooling and replicate padding. This is the expected output:
B = torch.tensor([[[0, 0, 1, 1],
                   [0, 0, 1, 1],
                   [1, 1, 0, 0],
                   [1, 1, 0, 0]]])
Explanation: for each element in A, we expand it to a matrix of shape (K, K) and put it in the result tensor. We continue to do this with the other elements, with the stride between them equal to the kernel size (that is, K).
How can I do this in PyTorch? Currently, A is a binary mask, but it would be better if the approach also extends to the non-binary case.
Answers:
Square expansion
You can get your desired output by expanding twice:
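def dilate(t, k):
    x = t.squeeze()  # drop the singleton batch dimension: (1, H, W) -> (H, W)
    x = x.unsqueeze(-1).expand([*x.shape, k])  # (H, W, k)
    x = x.unsqueeze(-1).expand([*x.shape, k])  # (H, W, k, k)
    x = torch.cat([*x], dim=1)  # (W, H*k, k)
    x = torch.cat([*x], dim=1)  # (H*k, W*k)
    x = x.unsqueeze(0)  # restore the batch dimension
    return x
B = dilate(A, K)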
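Because dilate() squeezes its input, it handles only a single matrix at a time. A minimal sketch of a batched variant follows, assuming an input of shape (batch_size, H, W); the name dilate_batched is hypothetical and not part of the original answer.

```python
import torch

def dilate_batched(t, k):
    # Hypothetical helper: t has shape (batch, H, W).
    b, h, w = t.shape
    # Insert singleton axes after H and W, then broadcast them to size k.
    x = t[:, :, None, :, None].expand(b, h, k, w, k)
    # Merge (H, k) and (W, k) so each element becomes a k x k block.
    return x.reshape(b, h * k, w * k)

A = torch.tensor([[[0, 1],
                   [1, 0]]])
B = dilate_batched(A, 2)
```

The expand call creates a view without copying; the data is only materialised by the final reshape.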
Resizing / interpolating nearest
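If you don’t mind corners potentially ‘bleeding’ in larger expansions (since it uses Euclidean as opposed to Manhattan distance when determining ‘nearest’ points to interpolate), a simpler method is to just resize with nearest-neighbour interpolation. Note that resize defaults to bilinear interpolation, so the mode has to be passed explicitly:
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode
B = F.resize(A, A.shape[-1]*K, interpolation=InterpolationMode.NEAREST)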
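The same nearest-neighbour idea can be sketched with core PyTorch only, using torch.nn.functional.interpolate. This is a sketch under assumptions: the helper name upsample_nearest is hypothetical, the input is assumed to have shape (batch_size, H, W), and the cast to float is a precaution because integer dtypes may not be supported by the interpolation kernels.

```python
import torch
import torch.nn.functional as nnf

def upsample_nearest(t, k):
    # Hypothetical helper: t has shape (batch, H, W).
    # interpolate expects (N, C, H, W), so add a channel axis; cast to float
    # because nearest-neighbour interpolation may not support integer dtypes.
    x = t.unsqueeze(1).float()
    x = nnf.interpolate(x, scale_factor=k, mode="nearest")
    # Drop the channel axis and restore the original dtype.
    return x.squeeze(1).to(t.dtype)

A = torch.tensor([[[0, 1],
                   [1, 0]]])
B = upsample_nearest(A, 2)
```

For the K=2 case in the question this yields the expected block pattern; for other factors it is worth comparing against one of the exact methods before relying on it.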
For completeness: MaxUnpool2d takes as input the output of MaxPool2d, including the indices of the maximal values, and computes a partial inverse in which all non-maximal values are set to zero.
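To illustrate why MaxUnpool2d is not a direct fit for the question, here is a small sketch with a made-up 2x2 input: only the maximum is restored to its position, and the remaining entries become zero rather than copies.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, stride=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2, stride=2)

# Illustrative input of shape (N=1, C=1, 2, 2).
x = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])
pooled, indices = pool(x)            # keeps only the maximum of the window
restored = unpool(pooled, indices)   # maximum goes back, the rest is zero
```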
You can try these:
Note: The functions below take a 2D tensor as input. If your tensor A is of shape (1, N, N), i.e., has a (redundant) batch/channel dimension, pass A.squeeze() to func().
Method 1:
This method uses broadcasted multiplication followed by transpose and reshape operations to achieve the final result.
import torch
import torch.nn as nn
A = torch.tensor([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
K = 3
def func(A, K):
    ones = torch.ones(K, K)  # float block, so the result is a float tensor
    tmp = ones.unsqueeze(0) * A.view(-1, 1, 1)  # (H*W, K, K): one block per element
    tmp = tmp.reshape(A.shape[0], A.shape[1], K, K)  # (H, W, K, K)
    res = tmp.transpose(1, 2).reshape(K * A.shape[0], K * A.shape[1])  # (H*K, W*K)
    return res
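Multiplying each element by a K x K block of ones and laying the blocks out in a grid amounts to a Kronecker product of A with a matrix of ones, so the same result can probably be obtained with torch.kron. The sketch below assumes a 2D input; kron_upsample is a hypothetical name, not part of the original answer.

```python
import torch

def kron_upsample(A, K):
    # Hypothetical helper; the block of ones uses A's dtype so that
    # integer inputs stay integer.
    block = torch.ones(K, K, dtype=A.dtype)
    return torch.kron(A, block)

A = torch.tensor([[0, 1], [1, 0]])
B = kron_upsample(A, 2)
```

Unlike func(), this variant keeps the input dtype because the block of ones is created with dtype=A.dtype.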
Method 2:
From @Shai’s hint in the comments, this method repeats the (2D) tensor K**2 times along the channel dimension and then uses PixelShuffle() to upscale the rows and columns by a factor of K.
def pixelshuffle(A, K):
    pixel_shuffle = nn.PixelShuffle(K)
    return pixel_shuffle(A.unsqueeze(0).repeat(K**2, 1, 1).unsqueeze(0)).squeeze(0).squeeze(0)
Since nn.PixelShuffle() takes only 4D tensors as input, unsqueezing after the repeat() was necessary. Also note that, since the tensor returned from nn.PixelShuffle() is also 4D, the two squeeze() calls follow to ensure we get a 2D tensor as output.
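If the batch dimension should be kept rather than squeezed away, the same channel-repeat idea can be sketched for inputs of shape (batch_size, H, W). The helper name pixelshuffle_batched is hypothetical; it uses the functional form torch.nn.functional.pixel_shuffle.

```python
import torch
import torch.nn.functional as nnf

def pixelshuffle_batched(A, K):
    # Hypothetical helper: A has shape (batch, H, W).
    # Copy A into K**2 identical channels: (batch, K**2, H, W).
    x = A.unsqueeze(1).repeat(1, K**2, 1, 1)
    # pixel_shuffle rearranges the K**2 channels into K x K spatial blocks,
    # giving (batch, 1, H*K, W*K); drop the channel axis afterwards.
    return nnf.pixel_shuffle(x, K).squeeze(1)

A = torch.tensor([[[0, 1],
                   [1, 0]]])
B = pixelshuffle_batched(A, 2)
```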
Some example outputs:
A = torch.tensor([[0, 1], [1, 0]])
func(A, 2)
# tensor([[0., 0., 1., 1.],
#         [0., 0., 1., 1.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.]])
pixelshuffle(A, 2)
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0]])
Feel free to ask for further clarifications and let me know if it works for you.
Benchmarking:
I benchmarked my answers func() and pixelshuffle() against @iacob’s dilate() function above and found that mine are faster (roughly 58 µs and 82 µs versus 142 µs per call in the timings below).
A = torch.randint(3, 100, (20, 20))
assert (dilate(A, 5) == func(A, 5)).all()
assert (dilate(A, 5) == pixelshuffle(A, 5)).all()
%timeit dilate(A, 5)
# 142 µs ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit func(A, 5)
# 57.9 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
%timeit pixelshuffle(A, 5)
# 81.6 µs ± 970 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
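The %timeit magic only works inside IPython or Jupyter. As a rough sketch of how a comparable measurement could be taken in a plain script, torch.utils.benchmark.Timer can be used. The upsample function here is only a stand-in built on repeat_interleave; swap in whichever candidate you want to time. No timings are shown because they depend on the machine.

```python
import torch
from torch.utils import benchmark

def upsample(A, K):
    # Stand-in function to time; replace with any of the candidates.
    return A.repeat_interleave(K, dim=0).repeat_interleave(K, dim=1)

A = torch.randint(3, 100, (20, 20))

timer = benchmark.Timer(
    stmt="upsample(A, 5)",
    globals={"upsample": upsample, "A": A},
)
measurement = timer.timeit(1000)  # run the statement 1000 times
print(measurement)
```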
For a 2D tensor, I would use repeat_interleave.
>>> import torch
>>> x = torch.tensor([[1, 2], [3, 4]])
>>> x
tensor([[1, 2],
        [3, 4]])
>>> torch.repeat_interleave(x, 2, dim=0)
tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(x, 2, dim=1)
tensor([[1, 1, 2, 2],
        [3, 3, 4, 4]])
>>> torch.repeat_interleave(torch.repeat_interleave(x, 2, dim=0), 2, dim=1)
tensor([[1, 1, 2, 2],
        [1, 1, 2, 2],
        [3, 3, 4, 4],
        [3, 3, 4, 4]])
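The two nested calls can be wrapped in a small function. As a sketch, repeating along the last two dimensions (dim=-2 and dim=-1) should let the same code handle both a 2D tensor and the (batch_size, H, W) tensor from the question; upsample_blocks is a hypothetical name.

```python
import torch

def upsample_blocks(A, K):
    # Hypothetical helper: repeat along the last two axes, so both (H, W)
    # and (batch, H, W) inputs are handled.
    return A.repeat_interleave(K, dim=-2).repeat_interleave(K, dim=-1)

A = torch.tensor([[[0, 1],
                   [1, 0]]])
B = upsample_blocks(A, 2)
```

This works for non-binary values as well, since elements are only copied and never combined.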