How to make a batch of image patches, not separate patches

Question:

I have written some lines of code to extract 5 bounding boxes/patches from a single image. When I run this code and print the output shape, it is (5, 256): five patches, each represented by a 256-dimensional vector. The problem is that the patches are not grouped by image: when I feed 5000+ images to this code, it generates 5000*5 patches mixed with each other, so the relationship between each patch and its source image is lost. I want to change the code so that the output keeps batch information, with a shape like (1, 5, 256) per image, so that each entry along the batch axis represents one image.

def create_vision_encoder(
num_projection_layers, projection_dims, dropout_rate, trainable=False
):
  """Build a Keras model that crops boxes from an image input and embeds them.

  A pretrained Xception backbone (ImageNet weights, no classification head,
  average pooling) embeds each crop, and ``project_embeddings`` (defined
  elsewhere) projects the resulting embeddings.

  Args:
    num_projection_layers: forwarded to ``project_embeddings``.
    projection_dims: forwarded to ``project_embeddings``.
    dropout_rate: forwarded to ``project_embeddings``.
    trainable: value assigned to ``trainable`` on every Xception layer;
      the default ``False`` keeps the backbone frozen.

  Returns:
    A ``keras.Model`` named ``"vision_encoder"`` whose input is a batch of
    ``(299, 299, 3)`` images.

  NOTE(review): the crops are NOT grouped per image. ``box_indices`` picks,
  for each of the NUM_BOXES boxes, a random image index in
  ``[0, BATCH_SIZE)``, and ``tf.image.crop_and_resize`` returns one row per
  box. The output therefore has NUM_BOXES rows whatever the batch size, and
  the patch/image relationship is lost. ``BATCH_SIZE`` is also not defined
  in this function; it must exist at module level.
  """

  xception = keras.applications.Xception(
    include_top=False, weights="imagenet", pooling="avg"
)
  # Freeze (or unfreeze) every backbone layer according to `trainable`.
  for layer in xception.layers:
    layer.trainable = trainable

  # Symbolic input; the batch dimension is implicit and not part of `shape`.
  inputs = layers.Input(shape=(299, 299, 3), name="image_input")

  # NOTE(review): CHANNELS is not used anywhere in this function.
  NUM_BOXES = 5
  CHANNELS = 3
  CROP_SIZE = (200, 200)

  # NUM_BOXES random normalized [y1, x1, y2, x2] boxes with coordinates in
  # [0, 1). They do not depend on `inputs`.
  boxes = tf.random.uniform(shape=(NUM_BOXES, 4))
  # For each box, the index of the batch image it is cropped from, drawn at
  # random from [0, BATCH_SIZE). This is what mixes crops across images.
  box_indices = tf.random.uniform(shape=(NUM_BOXES,), minval=0,
  maxval=BATCH_SIZE, dtype=tf.int32)
  # One crop per box: (NUM_BOXES, 200, 200, 3). The per-image batch axis is
  # replaced by the box axis here.
  output = tf.image.crop_and_resize(inputs, boxes, box_indices, CROP_SIZE)

  # Apply Xception's expected input scaling, then embed every crop.
  xception_input = tf.keras.applications.xception.preprocess_input(output)
  embeddings = xception(xception_input)

  # Project each crop embedding with the externally defined helper.
  outputs = project_embeddings(
    embeddings, num_projection_layers, projection_dims, dropout_rate
)

  return keras.Model(inputs, outputs, name="vision_encoder")
Asked By: Jacob

||

Answers:

You can create an ImagePatchesAndEmbedding layer that crops the bounding boxes from each image separately, stacks the crops per image, and applies Xception to them:

class ImagePatchesAndEmbedding(keras.layers.Layer):
    """Crop a fixed set of boxes from every image and embed each crop.

    Every image in the batch is cropped on its own, so the result keeps one
    entry per image along the first axis. For an encoder that maps
    ``(n, h, w, c)`` crops to ``(n, d)`` vectors (e.g. Xception with
    ``pooling="avg"``), the output is shaped ``(batch, num_boxes, d)``.

    Args:
        crop_size: ``(crop_height, crop_width)`` that every box is resized to.
        num_boxes: Number of boxes cropped from each image.
        minval: Unused; accepted only for backward compatibility.
        maxval: Unused; accepted only for backward compatibility. Each image
            is cropped as a single-image batch, so every box index is 0.
        encoder: Model applied to each image's preprocessed crops. When
            ``None`` (the default), a module-level ``xception`` model is
            looked up at call time, as the original snippet did.
        **kwargs: Forwarded to ``keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, crop_size, num_boxes=5, minval=0, maxval=1,
                 encoder=None, **kwargs):
        super().__init__(**kwargs)
        self.crop_size = crop_size
        self.num_boxes = num_boxes
        # Normalized [y1, x1, y2, x2] boxes in [0, 1), drawn once and reused
        # for every image.
        self.boxes = tf.random.uniform(shape=(num_boxes, 4))
        # crop_and_resize is called on img[None, ...] (a batch of one), so
        # every box must refer to image 0. The previous integer draw in
        # [0, 1) always produced 0; zeros state that directly.
        self.box_indices = tf.zeros(shape=(num_boxes,), dtype=tf.int32)
        self.preprocess = tf.keras.applications.xception.preprocess_input
        # Keeping the encoder as an attribute lets Keras track its weights
        # as part of this layer.
        self.encoder = encoder

    def call(self, inputs):
        encoder = self.encoder if self.encoder is not None else xception
        # Per image: (num_boxes, crop_height, crop_width, channels).
        # crop_and_resize always returns float32, so declare it; otherwise
        # map_fn rejects non-float32 inputs.
        patches = tf.map_fn(
            lambda img: tf.image.crop_and_resize(
                img[None, ...], self.boxes, self.box_indices, self.crop_size),
            inputs,
            fn_output_signature=tf.float32,
        )
        # Embed each image's crops; the results are stacked along the batch
        # axis, keeping the patch/image grouping.
        embeddings = tf.map_fn(
            lambda patch: encoder(self.preprocess(patch)), patches)
        return embeddings

Build the model:

NUM_BOXES = 5
CHANNELS = 3
CROP_SIZE = (200, 200)
BATCH_SIZE = 3
inputs = layers.Input(shape=(299, 299, CHANNELS), name="image_input")
# The layer crops each image on its own, so it needs no batch-size argument
# (a `maxval` passed here would be ignored); BATCH_SIZE is only used for the
# example call.
output = ImagePatchesAndEmbedding(CROP_SIZE, num_boxes=NUM_BOXES)(inputs)
model = keras.Model(inputs, output)

Call the model:

# Run a random batch of BATCH_SIZE images through the model; the output
# keeps one row per image: (batch, boxes per image, embedding size).
model(tf.random.normal((BATCH_SIZE, 299, 299, 3))).shape
# -> [3, 5, 2048]