How to fit Keras ImageDataGenerator for large data sets using batches

Question:

I want to use the Keras ImageDataGenerator for data augmentation.
To do so, I have to call the .fit() function on the instantiated ImageDataGenerator object using my training data as parameter as shown below.

image_datagen = ImageDataGenerator(featurewise_center=True, rotation_range=90)
image_datagen.fit(X_train, augment=True)
train_generator = image_datagen.flow_from_directory('data/images')
model.fit_generator(train_generator, steps_per_epoch=2000, epochs=50)

However, my training data set is too large to fit into memory when loaded up at once.
Consequently, I would like to fit the generator in several steps using subsets of my training data.

Is there a way to do this?

One potential solution that came to my mind is to load batches of my training data with a custom generator function and fit the image generator multiple times in a loop. However, I am not sure whether the fit function of ImageDataGenerator can be used this way, as it might reset its statistics on each call.

As an example of how it might work:

def custom_train_generator():
    # Code loading training data subsets X_batch
    yield X_batch


image_datagen = ImageDataGenerator(featurewise_center=True, rotation_range=90)
gen = custom_train_generator()

for batch in gen:
    image_datagen.fit(batch, augment=True)

train_generator = image_datagen.flow_from_directory('data/images')
model.fit_generator(train_generator, steps_per_epoch=2000, epochs=50)
Asked By: C.S.


Answers:

NEWER TF VERSIONS (>=2.5):

ImageDataGenerator() has been deprecated in favour of:

tf.keras.utils.image_dataset_from_directory

Its signature, from the documentation:

tf.keras.utils.image_dataset_from_directory(
    directory,
    labels='inferred',
    label_mode='int',
    class_names=None,
    color_mode='rgb',
    batch_size=32,
    image_size=(256, 256),
    shuffle=True,
    seed=None,
    validation_split=None,
    subset=None,
    interpolation='bilinear',
    follow_links=False,
    crop_to_aspect_ratio=False,
    **kwargs
)

OLDER TF VERSIONS (<2.5)

ImageDataGenerator() can already load the data in batches. The batch_size parameter belongs to its flow() and flow_from_directory() methods, and the iterator they return can be passed directly to fit_generator(). There is no need to write a generator from scratch, although you can if you prefer.

IMPORTANT NOTE:

Starting from TensorFlow 2.1, .fit_generator() has been deprecated and you should use .fit()

Example taken from Keras official documentation:

datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)

# compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied)
datagen.fit(x_train)

# TF <= 2.0
# fits the model on batches with real-time data augmentation:
model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
                    steps_per_epoch=len(x_train) // 32, epochs=epochs)

# TF >= 2.1
model.fit(datagen.flow(x_train, y_train, batch_size=32),
         steps_per_epoch=len(x_train) // 32, epochs=epochs)
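
On the question of calling datagen.fit() once per subset: as far as I can tell from the Keras source, each call recomputes the mean and std from the array it receives and overwrites the previous values, so looping over batches would only keep the statistics of the last batch. One possible workaround is to accumulate the per-channel statistics yourself. The sketch below is only that, a sketch: the helper name running_channel_stats is hypothetical, it assumes channels-last batches of shape (n, height, width, channels), and it ignores the augment=True option, so the statistics come from the raw images.

```python
import numpy as np


def running_channel_stats(batches):
    """Accumulate per-channel mean and std over an iterable of image batches.

    Hypothetical helper, not part of Keras. Each batch is assumed to be an
    array of shape (n, height, width, channels).
    """
    count = 0          # number of pixels seen per channel
    total = None       # running sum per channel
    total_sq = None    # running sum of squares per channel

    for batch in batches:
        batch = np.asarray(batch, dtype=np.float64)
        if total is None:
            total = np.zeros(batch.shape[-1])
            total_sq = np.zeros(batch.shape[-1])
        total += batch.sum(axis=(0, 1, 2))
        total_sq += np.square(batch).sum(axis=(0, 1, 2))
        count += batch.shape[0] * batch.shape[1] * batch.shape[2]

    if count == 0:
        raise ValueError("no batches were provided")

    mean = total / count
    # Population variance (ddof=0); clip tiny negatives caused by rounding.
    var = np.maximum(total_sq / count - mean ** 2, 0.0)
    return mean, np.sqrt(var)
```

With a generator such as the custom_train_generator() from the question, the result could then be assigned instead of calling fit(), for example datagen.mean = mean.reshape(1, 1, -1) and datagen.std = std.reshape(1, 1, -1). Setting these attributes directly is an assumption about how ImageDataGenerator stores its statistics rather than a documented API, so check it against your Keras version.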

I would suggest reading this excellent article about ImageDataGenerator and augmentation: https://machinelearningmastery.com/how-to-configure-image-data-augmentation-when-training-deep-learning-neural-networks/

The solution to your problem lies in this line of code (using either flow or flow_from_directory), which yields the data one batch at a time:

# prepare iterator
it = datagen.flow(samples, batch_size=1)

For creating your own DataGenerator, this link is a good starting point: https://stanford.edu/~shervine/blog/keras-how-to-generate-data-on-the-fly

IMPORTANT NOTE (2):

If you use the Keras bundled with TensorFlow (tf.keras), then in both the code presented here and the tutorials you consult, make sure you replace imports of the form:

from keras.x.y.z import A

WITH

from tensorflow.keras.x.y.z import A
Answered By: Timbus Calin