Pytorch: a similar process to reverse pooling and replicate padding?

Question:

I have a tensor A that has shape (batch_size, width, height). Assume that it has these values:

A = torch.tensor([[[0, 1],
                   [1, 0]]])

I am also given a number K that is a positive integer. Let K=2 in this case. I want to do a process that is similar to reverse pooling and replicate padding. This is the expected output:

B = torch.tensor([[[0, 0, 1, 1],
                   [0, 0, 1, 1],
                   [1, 1, 0, 0],
                   [1, 1, 0, 0]]])

Explanation: for each element in A, we expand it to a matrix of shape (K, K) and place it in the result tensor. We repeat this for the other elements, with the stride between blocks equal to the kernel size (that is, K).

How can I do this in PyTorch? Currently, A is a binary mask, but it would be better if the method also works in the non-binary case.
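
To make the expected behavior concrete, here is a naive loop version of what I want (just a reference sketch; expand_naive is a throwaway name):

import torch

def expand_naive(a, k):
    # a: (batch, H, W) -> (batch, H*k, W*k); each value fills a k-by-k block
    b = torch.zeros(a.shape[0], a.shape[1] * k, a.shape[2] * k, dtype=a.dtype)
    for i in range(a.shape[1]):
        for j in range(a.shape[2]):
            b[:, i * k:(i + 1) * k, j * k:(j + 1) * k] = a[:, i, j].view(-1, 1, 1)
    return b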

Asked By: Minh-Long Luu


Answers:

Square expansion

You can get your desired output by expanding twice:

def dilate(t, k):
  x = t.squeeze()
  # add a trailing dim and expand it to size k: each value is repeated k times
  x = x.unsqueeze(-1).expand([*x.shape, k])
  # expand again along another trailing dim, giving one (k, k) block per element
  x = x.unsqueeze(-1).expand([*x.shape, k])
  # concatenate the blocks back together, first along rows, then along columns
  x = torch.cat([*x], dim=1)
  x = torch.cat([*x], dim=1)
  # restore the batch dimension
  x = x.unsqueeze(0)
  return x

B = dilate(A, K)
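
A quick sanity check against the expected output from the question:

expected = torch.tensor([[[0, 0, 1, 1],
                          [0, 0, 1, 1],
                          [1, 1, 0, 0],
                          [1, 1, 0, 0]]])
assert torch.equal(dilate(A, K), expected)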

Resizing / interpolating nearest

If you don’t mind corners potentially ‘bleeding’ in larger expansions (resizing uses Euclidean rather than Manhattan distance when determining the ‘nearest’ points to interpolate from), a simpler method is to just resize:

import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

# request nearest-neighbour interpolation explicitly (the default is bilinear)
B = F.resize(A, A.shape[-1] * K, interpolation=InterpolationMode.NEAREST)
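
A closely related alternative that stays in core PyTorch is torch.nn.functional.interpolate; note it expects a floating-point tensor with an explicit channel dimension, so a rough sketch looks like:

import torch.nn.functional as nnF

# interpolate wants (N, C, H, W) and a float dtype for its input
B = nnF.interpolate(A.unsqueeze(1).float(), scale_factor=K, mode="nearest")
B = B.squeeze(1).to(A.dtype)  # back to (batch, H*K, W*K) with the original dtype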

For completeness:

MaxUnpool2d takes in as input the output of MaxPool2d including the indices of the maximal values and computes a partial inverse in which all non-maximal values are set to zero.
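
To see why this is only a partial inverse, a minimal sketch using the standard MaxPool2d / MaxUnpool2d pairing:

import torch
import torch.nn as nn

pool = nn.MaxPool2d(2, return_indices=True)  # keep the locations of the maxima
unpool = nn.MaxUnpool2d(2)

x = torch.tensor([[[[1., 2.],
                    [3., 4.]]]])             # shape (1, 1, 2, 2)
y, idx = pool(x)                             # y == [[[[4.]]]]
z = unpool(y, idx)
# z == [[[[0., 0.],
#         [0., 4.]]]]  -- only the maximum is restored, everything else is zero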

Answered By: iacob

You can try these:

Note: the functions below take a 2D tensor as input. If your tensor A has shape (1, N, N), i.e., a (redundant) batch/channel dimension, pass A.squeeze() to func().

Method 1:

This method uses broadcasted multiplication followed by transpose and reshape operations to achieve the final result.

import torch
import torch.nn as nn

A = torch.tensor([[0, 1, 1], [1, 0, 1], [1, 1, 0]])
K = 3

def func(A, K):
    ones = torch.ones(K, K)
    # broadcast: one (K, K) block per element of A, scaled by that element
    tmp = ones.unsqueeze(0) * A.view(-1, 1, 1)
    # (N*M, K, K) -> (N, M, K, K)
    tmp = tmp.reshape(A.shape[0], A.shape[1], K, K)
    # interleave block rows with element rows, then flatten to (N*K, M*K)
    res = tmp.transpose(1, 2).reshape(K * A.shape[0], K * A.shape[1])
    return res

Method 2:

Following @Shai’s hint in the comments, this method repeats the (2D) tensor K**2 times along the channel dimension and then uses PixelShuffle() to upscale the rows and columns by a factor of K.

def pixelshuffle(A, K):
    pixel_shuffle = nn.PixelShuffle(K)
    return pixel_shuffle(A.unsqueeze(0).repeat(K**2, 1, 1).unsqueeze(0)).squeeze(0).squeeze(0)

Since nn.PixelShuffle() takes only 4D tensors as input, the unsqueeze() after repeat() is necessary. Also note that, since the tensor returned by nn.PixelShuffle() is also 4D, the two squeeze() calls at the end ensure we get a 2D tensor as output.
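
To make the shape bookkeeping explicit (assuming a square 2D input of side N):

# A:                       (N, N)
# A.unsqueeze(0):          (1, N, N)
# .repeat(K**2, 1, 1):     (K**2, N, N)
# .unsqueeze(0):           (1, K**2, N, N)   <- 4D input required by PixelShuffle
# nn.PixelShuffle(K):      (1, 1, N*K, N*K)
# .squeeze(0).squeeze(0):  (N*K, N*K)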

Some example outputs:

A = torch.tensor([[0, 1], [1, 0]])
func(A, 2)
# tensor([[0., 0., 1., 1.],
#         [0., 0., 1., 1.],
#         [1., 1., 0., 0.],
#         [1., 1., 0., 0.]])

pixelshuffle(A, 2)
# tensor([[0, 0, 1, 1],
#         [0, 0, 1, 1],
#         [1, 1, 0, 0],
#         [1, 1, 0, 0]])
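
Note that func() returns floats because torch.ones() defaults to float32, while pixelshuffle() preserves the integer dtype. If you want func() to keep A's dtype, create the block of ones accordingly:

ones = torch.ones(K, K, dtype=A.dtype)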

Feel free to ask for further clarifications and let me know if it works for you.

Benchmarking:

I benchmarked my answers func() and pixelshuffle() against @iacob’s dilate() function above and found that mine are slightly faster.

A = torch.randint(3, 100, (20, 20))
assert (dilate(A, 5) == func(A, 5)).all()
assert (dilate(A, 5) == pixelshuffle(A, 5)).all()

%timeit dilate(A, 5)
# 142 µs ± 2.54 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit func(A, 5)
# 57.9 µs ± 1.67 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)

%timeit pixelshuffle(A, 5)
# 81.6 µs ± 970 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Answered By: swag2198

For a 2D tensor, I would use repeat_interleave.

>>> import torch
>>> x = torch.tensor([[1, 2], [3, 4]])
>>> x
tensor([[1, 2],
        [3, 4]])
>>> torch.repeat_interleave(x, 2, dim=0)
tensor([[1, 2],
        [1, 2],
        [3, 4],
        [3, 4]])
>>> torch.repeat_interleave(x, 2, dim=1)
tensor([[1, 1, 2, 2],
        [3, 3, 4, 4]])
>>> torch.repeat_interleave(torch.repeat_interleave(x, 2, dim=0), 2, dim=1)
tensor([[1, 1, 2, 2],
        [1, 1, 2, 2],
        [3, 3, 4, 4],
        [3, 3, 4, 4]])
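
To handle the (batch_size, height, width) tensor from the question directly, the same two calls can be wrapped in a small helper (expand_blocks is just an illustrative name):

def expand_blocks(a, k):
    # repeat each row k times, then each column k times; leading dims are untouched
    return torch.repeat_interleave(torch.repeat_interleave(a, k, dim=-2), k, dim=-1)

B = expand_blocks(A, K)  # (batch_size, height*K, width*K)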
Answered By: Renjie Chen