Is there an equivalent PyTorch function for `tf.nn.space_to_depth`?
Question:
As the title says, is there an equivalent PyTorch function for tf.nn.space_to_depth?
Answers:
While torch.nn.functional.pixel_shuffle does exactly what tf.nn.depth_to_space does, PyTorch doesn’t have any function to do the inverse operation similar to tf.nn.space_to_depth.
That being said, it is easy to implement space_to_depth using torch.nn.functional.unfold.
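    import torch

    def space_to_depth(x, block_size):
        n, c, h, w = x.size()
        # unfold extracts non-overlapping block_size x block_size patches: (n, c * block_size ** 2, L)
        unfolded_x = torch.nn.functional.unfold(x, block_size, stride=block_size)
        return unfolded_x.view(n, c * block_size ** 2, h // block_size, w // block_size)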
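A quick way to sanity-check the unfold approach is to confirm that pixel_shuffle undoes it. The sketch below assumes a small float test tensor and block size 2, both chosen only for illustration. It also compares against torch.nn.functional.pixel_unshuffle, which as far as I know only exists in newer PyTorch releases, so that check is guarded.

```python
import torch
import torch.nn.functional as F


def space_to_depth_unfold(x, block_size):
    # Same approach as the answer above: non-overlapping patches via unfold.
    n, c, h, w = x.size()
    patches = F.unfold(x, block_size, stride=block_size)
    return patches.view(n, c * block_size ** 2, h // block_size, w // block_size)


# Illustrative input: batch 2, 3 channels, 4x6 spatial size.
x = torch.arange(2 * 3 * 4 * 6, dtype=torch.float32).reshape(2, 3, 4, 6)
y = space_to_depth_unfold(x, 2)

# pixel_shuffle should undo the unfold-based rearrangement.
round_trip_ok = torch.equal(F.pixel_shuffle(y, 2), x)

# Assumption: pixel_unshuffle is only present in newer PyTorch releases.
if hasattr(F, "pixel_unshuffle"):
    matches_builtin = torch.equal(F.pixel_unshuffle(x, 2), y)
else:
    matches_builtin = None
```

If the round trip holds, the unfold version is the inverse of pixel_shuffle. Whether it also matches the channel ordering of tf.nn.space_to_depth is a separate question.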
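Actually, @Priyatham’s unfold version does not match tf.nn.space_to_depth. Unfold expands each input channel into block_size * block_size consecutive output channels, whereas space-to-depth keeps the input channels together and repeats that whole group block_size * block_size times, once per position inside the block. The values are the same, but the channel order differs whenever there is more than one input channel.
So, the right way is to use einops.rearrange():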
    import einops

    result = einops.rearrange(x, 'b c (h p1) (w p2) -> b (p1 p2 c) h w', p1=block_size, p2=block_size)
You can also follow the official TResNet source code on GitHub:
https://github.com/Alibaba-MIIL/TResNet/blob/master/src/models/tresnet/layers/space_to_depth.py
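For readers who would rather not depend on einops, the same block-major ordering can be sketched with plain view/permute calls. This is a minimal sketch in the spirit of the linked module, not a copy of it; the function name space_to_depth_block_major and the test tensor are made up for illustration. It also compares the result with the unfold version to show that the two differ only in channel order.

```python
import torch
import torch.nn.functional as F


def space_to_depth_block_major(x, block_size):
    # Output channel index is (p1 * block_size + p2) * c + channel,
    # i.e. block position varies slowest and the input channel fastest.
    n, c, h, w = x.size()
    x = x.view(n, c, h // block_size, block_size, w // block_size, block_size)
    x = x.permute(0, 3, 5, 1, 2, 4).contiguous()  # (n, p1, p2, c, h', w')
    return x.view(n, c * block_size ** 2, h // block_size, w // block_size)


# Illustrative input: batch 1, 2 channels, 4x4 spatial size.
x = torch.arange(1 * 2 * 4 * 4, dtype=torch.float32).reshape(1, 2, 4, 4)
block_major = space_to_depth_block_major(x, 2)

# Channel-major ordering produced by the unfold version, for comparison.
channel_major = F.unfold(x, 2, stride=2).view(1, 8, 2, 2)

# Both hold the same values at every spatial position ...
same_values = torch.equal(
    block_major.sort(dim=1).values, channel_major.sort(dim=1).values
)
# ... but in a different channel order when there is more than one channel.
same_order = torch.equal(block_major, channel_major)
```

With a single input channel the two orderings coincide, so the difference only matters for multi-channel inputs or when weights trained against one ordering are reused with the other.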