Understanding tf.extract_image_patches for extracting patches from an image

Question:

I found the method tf.extract_image_patches in the TensorFlow API, but I am not clear about what it does.

Say the batch_size = 1, and an image is of size 225x225x3, and we want to extract patches of size 32x32.

How exactly does this function behave? Specifically, the documentation gives the dimensions of the output tensor as [batch, out_rows, out_cols, ksize_rows * ksize_cols * depth], but it does not say what out_rows and out_cols are.

Ideally, given an input image tensor of size 1x225x225x3 (where 1 is the batch size), I want to be able to get Kx32x32x3 as output, where K is the total number of patches and 32x32x3 is the dimension of each patch. Is there something in tensorflow that already achieves this?

Asked By: deeptigp


Answers:

Here is how the method works:

  • ksizes sets the dimensions of each patch, in other words how many pixels each patch contains. Like strides and rates, it is given as a 4-element list [1, rows, cols, 1].

  • strides denotes the gap between the start of one patch and the start of the next consecutive patch within the original image.

  • rates is a dilation factor: for each consecutive pixel that ends up in our patch, the patch jumps by rates pixels in the original image. With a rate of 2, for example, a 3x3 patch samples every other pixel and spans a 5x5 region. (The example below helps illustrate this.)

  • padding is either "VALID", which means every patch must be fully contained in the image, or "SAME", which means patches are allowed to extend past the image border (the missing pixels are filled in with zeroes).
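To make the four parameters concrete, here is a minimal loop-based NumPy sketch of what the op computes for a single image. The function name extract_patches_ref is hypothetical and not part of TensorFlow. Its "SAME" branch assumes TensorFlow's usual convolution padding rule (total padding split in half, any odd pixel going to the bottom/right); since extract_image_patches computes its own offsets through Eigen, corner cases may differ.

```python
import math

import numpy as np


def extract_patches_ref(image, ksize, stride, rate, padding):
    """Loop-based sketch of patch extraction for ONE image of shape (H, W, C).

    ksize, stride and rate are (rows, cols) pairs; padding is 'VALID' or 'SAME'.
    Returns shape (out_rows, out_cols, ksize_rows * ksize_cols * C), with each
    patch flattened in (row, column, channel) order.
    """
    h, w, c = image.shape
    # A dilated patch covers k + (k - 1) * (rate - 1) pixels along each axis.
    k_eff = [k + (k - 1) * (r - 1) for k, r in zip(ksize, rate)]
    if padding == 'VALID':
        out = [math.ceil((n - ke + 1) / s) for n, ke, s in zip((h, w), k_eff, stride)]
        pad_before = [0, 0]
    else:  # 'SAME': assumed TF convolution rule, extra padding goes bottom/right
        out = [math.ceil(n / s) for n, s in zip((h, w), stride)]
        pad_total = [max((o - 1) * s + ke - n, 0)
                     for o, s, ke, n in zip(out, stride, k_eff, (h, w))]
        pad_before = [p // 2 for p in pad_total]

    result = np.zeros((out[0], out[1], ksize[0] * ksize[1] * c), dtype=image.dtype)
    for i in range(out[0]):
        for j in range(out[1]):
            values = []
            for di in range(ksize[0]):
                for dj in range(ksize[1]):
                    # The patch origin moves by `stride`; pixels inside it by `rate`.
                    y = i * stride[0] + di * rate[0] - pad_before[0]
                    x = j * stride[1] + dj * rate[1] - pad_before[1]
                    if 0 <= y < h and 0 <= x < w:
                        values.extend(image[y, x])
                    else:
                        values.extend([0] * c)  # outside the image: zero padding
            result[i, j] = values
    return result


# The 10 x 10 test image used below: values 1..100 in row-major order.
image = np.arange(1, 101).reshape(10, 10, 1)
print(extract_patches_ref(image, (3, 3), (5, 5), (1, 1), 'VALID')[0, 0])
# expected: [ 1  2  3 11 12 13 21 22 23]
```

Calling it with the same arguments as the four TensorFlow calls below should give the same arrays, minus the leading batch dimension.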

Here is some sample code with output to help demonstrate how it works:

# TensorFlow 1.x API (tf.Session, tf.extract_image_patches).
# In TensorFlow 2 the op is tf.image.extract_patches, which takes sizes= instead of ksizes=.
import tensorflow as tf

n = 10
# images is a 1 x 10 x 10 x 1 array that contains the numbers 1 through 100 in order
images = [[[[x * n + y + 1] for y in range(n)] for x in range(n)]]

# We generate four outputs as follows:
# 1. 3x3 patches with stride length 5
# 2. Same as above, but the rate is increased to 2
# 3. 4x4 patches with stride length 7; only one patch should be generated
# 4. Same as above, but with padding set to 'SAME'
with tf.Session() as sess:  # .eval() below runs in this default session
    print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 3, 3, 1], strides=[1, 5, 5, 1], rates=[1, 2, 2, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='VALID').eval(), '\n\n')
    print(tf.extract_image_patches(images=images, ksizes=[1, 4, 4, 1], strides=[1, 7, 7, 1], rates=[1, 1, 1, 1], padding='SAME').eval())

Output:

[[[[ 1  2  3 11 12 13 21 22 23]
   [ 6  7  8 16 17 18 26 27 28]]

  [[51 52 53 61 62 63 71 72 73]
   [56 57 58 66 67 68 76 77 78]]]]


[[[[  1   3   5  21  23  25  41  43  45]
   [  6   8  10  26  28  30  46  48  50]]

  [[ 51  53  55  71  73  75  91  93  95]
   [ 56  58  60  76  78  80  96  98 100]]]]


[[[[ 1  2  3  4 11 12 13 14 21 22 23 24 31 32 33 34]]]]


[[[[  1   2   3   4  11  12  13  14  21  22  23  24  31  32  33  34]
   [  8   9  10   0  18  19  20   0  28  29  30   0  38  39  40   0]]

  [[ 71  72  73  74  81  82  83  84  91  92  93  94   0   0   0   0]
   [ 78  79  80   0  88  89  90   0  98  99 100   0   0   0   0   0]]]]

So, for example, in our first result the pixels that end up in some patch are the ones replaced by * below:

 *  *  *  4  5  *  *  *  9 10 
 *  *  * 14 15  *  *  * 19 20 
 *  *  * 24 25  *  *  * 29 30 
31 32 33 34 35 36 37 38 39 40 
41 42 43 44 45 46 47 48 49 50 
 *  *  * 54 55  *  *  * 59 60 
 *  *  * 64 65  *  *  * 69 70 
 *  *  * 74 75  *  *  * 79 80 
81 82 83 84 85 86 87 88 89 90 
91 92 93 94 95 96 97 98 99 100 

As you can see, we have 2 rows and 2 columns of patches; these counts are exactly out_rows and out_cols. Each 3x3x1 patch is flattened into the last dimension, so the output shape is [1, 2, 2, 9].
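The counts can also be computed directly. This is a sketch that assumes the standard VALID/SAME output-size rules used by TensorFlow's convolution ops, with the patch extent enlarged by the rate; output_size is a hypothetical helper name.

```python
import math


def output_size(in_size, ksize, stride, rate=1, padding='VALID'):
    """Number of patch positions along one axis, i.e. out_rows or out_cols."""
    k_eff = ksize + (ksize - 1) * (rate - 1)  # extent of a dilated patch
    if padding == 'VALID':
        # Only positions where the whole (dilated) patch fits inside the image.
        return math.ceil((in_size - k_eff + 1) / stride)
    # 'SAME': one patch per stride step; the border is zero-padded as needed.
    return math.ceil(in_size / stride)


# The four calls from the example above, along one axis of the 10 x 10 image:
print(output_size(10, 3, 5))                  # 2
print(output_size(10, 3, 5, rate=2))          # 2
print(output_size(10, 4, 7))                  # 1
print(output_size(10, 4, 7, padding='SAME'))  # 2
# The question's case: 32 x 32 patches with stride 32 on a 225 x 225 image.
print(output_size(225, 32, 32))               # 7
```

So for the question's 1x225x225x3 input with ksizes=[1, 32, 32, 1], strides=[1, 32, 32, 1] and 'VALID' padding, the output should be 1x7x7x3072, i.e. K = 49 patches, with the last pixel row and column left uncovered (7 * 32 = 224). Because each patch is flattened in (row, column, channel) order, tf.reshape(patches, [-1, 32, 32, 3]) should then give the K x 32 x 32 x 3 tensor the question asks for.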

Answered By: Neal

To expand on Neal’s detailed answer, there are a lot of subtleties with zero padding when using “SAME”, since extract_image_patches tries to center the patches in the image if possible. Depending on the stride, there may be padding on the top and left, or not, and the first patch doesn’t necessarily start in the upper left.

For example, extending the previous example:

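print(tf.extract_image_patches(images, [1, 3, 3, 1], [1, stride, stride, 1], [1, 1, 1, 1], 'SAME').eval()[0])

Here stride is the stride being tested (kept separate from n, which is still the image size 10), and the call runs inside the tf.Session() block above.

With a stride of 1, the image is padded with zeros all around, so the first patch starts with padding. Many other strides pad the image only on the right and bottom, or not at all, though some, such as stride 3, still split the padding with one row and column on the top and left.
With a stride of 10, the single patch starts at element 34 (in the middle of the image), at least in the TensorFlow version used for this answer.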
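The padding arithmetic can be sketched in a few lines. This assumes the SAME rule TensorFlow documents for its convolution ops (same_padding is a hypothetical helper). It reproduces the stride-1 case and shows why stride 3 pads on the top and left as well. Note that it clamps the total padding at zero, so for stride 10 it would start the single patch at element 1; it does not reproduce the centered stride-10 patch reported above, which likely reflects how that version's Eigen code computed the offset.

```python
import math


def same_padding(in_size, ksize, stride):
    """'SAME' padding along one axis -> (out_size, pad_before, pad_after).

    Assumes TensorFlow's convolution rule: the total padding is clamped at 0
    and split in half, with any odd pixel going to the bottom/right.
    """
    out = math.ceil(in_size / stride)
    pad_total = max((out - 1) * stride + ksize - in_size, 0)
    pad_before = pad_total // 2
    return out, pad_before, pad_total - pad_before


# 3x3 patches along one axis of the 10 x 10 image, for several strides.
for stride in [1, 2, 3, 4, 5, 10]:
    print(stride, same_padding(10, 3, stride))
# 1 (10, 1, 1)
# 2 (5, 0, 1)
# 3 (4, 1, 1)
# 4 (3, 0, 1)
# 5 (2, 0, 0)
# 10 (1, 0, 0)
```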

tf.extract_image_patches is implemented with the Eigen library, as described in this answer. You can study that code to see exactly how patch positions and padding are computed.

Answered By: Ken Shirriff

Introduction

Here I would like to present a simple demonstration of tf.image.extract_patches (the TensorFlow 2 counterpart of tf.extract_image_patches, which takes sizes instead of ksizes) applied to an actual image. I found few implementations of the method that use real images with proper visualizations, so here is one.

The image we will use is of size (256, 256, 3). The patches we will be extracting will be shaped (128, 128, 3). This means that we will retrieve 4 tiles from the image.

Data used

I will be using the flowers dataset. Because this answer needs a small data pipeline, I am linking my Kaggle kernel here, which covers consuming the dataset with the tf.data.Dataset API.

Once the pipeline is set up, we go through the following code snippets.

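import matplotlib.pyplot as plt
import tensorflow as tf

# train_ds is built in the linked Kaggle kernel (assumed to yield batches of (images, labels))
images, _ = next(iter(train_ds.take(1)))

image = images[0]  # first image of the batch, shape (256, 256, 3)
plt.imshow(image.numpy().astype("uint8"))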

[Image: the flower]

Here we take one image from the batch and visualize it as is.

image = tf.expand_dims(image, 0)  # add a batch dimension: (256, 256, 3) -> (1, 256, 256, 3)
patches = tf.image.extract_patches(images=image,
                                   sizes=[1, 128, 128, 1],
                                   strides=[1, 128, 128, 1],
                                   rates=[1, 1, 1, 1],
                                   padding='VALID')
# patches has shape (1, 2, 2, 128 * 128 * 3)

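With this snippet, we extract patches of size (128, 128) from the image of size (256, 256). Because the stride equals the patch size and the padding is 'VALID', the patches do not overlap, so the image is split into 4 tiles arranged in a 2 x 2 grid; patches has shape (1, 2, 2, 49152), where 49152 = 128 * 128 * 3.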
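With sizes equal to strides and 'VALID' padding, as here, extract_patches simply cuts the image into non-overlapping tiles. As a sketch under that assumption (tile_image is a hypothetical helper, not a TensorFlow function), the same K x h x w x c tiles can be produced in NumPy with a reshape and a transpose, which also makes the patch ordering explicit:

```python
import numpy as np


def tile_image(image, tile_h, tile_w):
    """Split an (H, W, C) array into non-overlapping tiles of shape (K, tile_h, tile_w, C).

    Intended to mirror extract_patches with sizes == strides and padding='VALID':
    pixels that do not fill a complete tile at the bottom/right are dropped.
    """
    h, w, c = image.shape
    rows, cols = h // tile_h, w // tile_w
    cropped = image[:rows * tile_h, :cols * tile_w]
    # (rows, tile_h, cols, tile_w, C) -> (rows, cols, tile_h, tile_w, C)
    grid = cropped.reshape(rows, tile_h, cols, tile_w, c).transpose(0, 2, 1, 3, 4)
    # Flatten the 2-D grid of tiles into K = rows * cols, in row-major order.
    return grid.reshape(rows * cols, tile_h, tile_w, c)


# A deterministic stand-in for the 256 x 256 x 3 flower image.
image = np.arange(256 * 256 * 3).reshape(256, 256, 3)
tiles = tile_image(image, 128, 128)
print(tiles.shape)  # (4, 128, 128, 3)
```

The tiles come out in row-major grid order (top-left, top-right, bottom-left, bottom-right), which should match the order obtained from tf.reshape(patches, [-1, 128, 128, 3]) on the TensorFlow output.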

Visualization

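plt.figure(figsize=(10, 10))
for imgs in patches:  # iterate over the batch (here a single entry of shape (2, 2, 49152))
    count = 0
    for r in range(2):
        for c in range(2):
            ax = plt.subplot(2, 2, count + 1)
            # un-flatten the patch back into a 128x128x3 image
            plt.imshow(tf.reshape(imgs[r, c], shape=(128, 128, 3)).numpy().astype("uint8"))
            plt.axis("off")
            count += 1
plt.show()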

[Image: the four splits of the flower]