Is it possible to split a tensorflow dataset into train, validation AND test datasets when using image_dataset_from_directory?

Question:

I am using tf.keras.utils.image_dataset_from_directory to load a dataset of 4575 images. While this function allows you to split the data into two subsets (with the validation_split parameter), I want to split it into training, testing, and validation subsets.

I have tried using dataset.skip() and dataset.take() to further split one of the resulting subsets, but these functions return a SkipDataset and a TakeDataset respectively (incidentally, contrary to the documentation, which claims that they return a Dataset). This leads to problems when fitting the model: the metrics calculated on the validation set (val_loss, val_accuracy) disappear from the model history.

So, my question is: is there a way to split a Dataset into three subsets for training, validation and testing, so that all three subsets are also Dataset objects?

Code used to load the data

import tensorflow as tf

def load_data_tf(data_path: str, img_shape=(256,256), batch_size: int=8):
    train_ds = tf.keras.utils.image_dataset_from_directory(
        data_path,
        validation_split=0.2,
        subset="training",
        label_mode='categorical',
        seed=123,
        image_size=img_shape,
        batch_size=batch_size)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        data_path,
        validation_split=0.3,
        subset="validation",
        label_mode='categorical',
        seed=123,
        image_size=img_shape,
        batch_size=batch_size)
    return train_ds, val_ds

train_dataset, test_val_ds = load_data_tf('data_folder', img_shape=(256, 256), batch_size=8)
test_dataset = test_val_ds.take(686)
val_dataset = test_val_ds.skip(686)

Model compilation and fitting

model.compile(optimizer='sgd',
              loss=tf.keras.losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])
history = model.fit(train_dataset, epochs=50, validation_data=val_dataset, verbose=1)

When using a normal Dataset, val_accuracy and val_loss are present in the history of the model:

Expected behaviour: when using a Dataset, validation metrics are calculated

But when using a SkipDataset, they are not:

Using the SkipDataset produced by test_val_ds.skip() as validation data leads to the validation metrics disappearing from the model history

val_accuracy and val_loss are not present in the history keys when the validation data is a SkipDataset or a TakeDataset.
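For reference, this is roughly how I check which metrics were recorded; the snippet and the printed keys below are just an illustration of what I described above:

print(history.history.keys())
# with a regular Dataset as validation_data:
#   dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
# with the SkipDataset produced by test_val_ds.skip(686):
#   dict_keys(['loss', 'accuracy'])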

Asked By: andrii kliachkin


Answers:

The issue is that test_val_ds.take(686) and test_val_ds.skip(686) are not taking and skipping individual samples, but batches, since image_dataset_from_directory returns a batched dataset. Try running print(val_dataset.cardinality()) and you will see how many batches you have really reserved for validation. My guess is that val_dataset is empty, because you do not have 686 batches to skip. Here is a working example:

import tensorflow as tf
import pathlib

dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file('flower_photos', origin=dataset_url, untar=True)
data_dir = pathlib.Path(data_dir)

batch_size = 32

train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="training",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.2,
  subset="validation",
  seed=123,
  image_size=(180, 180),
  batch_size=batch_size)

test_dataset = val_ds.take(5)
val_ds = val_ds.skip(5)

print('Batches for testing -->', test_dataset.cardinality())
print('Batches for validating -->', val_ds.cardinality())

model = tf.keras.Sequential([
  tf.keras.layers.Rescaling(1./255, input_shape=(180, 180, 3)),
  tf.keras.layers.Conv2D(16, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(64, 3, padding='same', activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(5)
])

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

epochs = 1
history = model.fit(
  train_ds,
  validation_data=val_ds,
  epochs=epochs
)
Found 3670 files belonging to 5 classes.
Using 2936 files for training.
Found 3670 files belonging to 5 classes.
Using 734 files for validation.
Batches for testing --> tf.Tensor(5, shape=(), dtype=int64)
Batches for validating --> tf.Tensor(18, shape=(), dtype=int64)
92/92 [==============================] - 96s 1s/step - loss: 1.3516 - accuracy: 0.4489 - val_loss: 1.1332 - val_accuracy: 0.5645

In this example, with a batch_size of 32, the validation split contains 23 batches (734 files / 32, rounded up). Afterwards, 5 of those batches are given to the test set and 18 batches remain for validation.
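If you prefer not to hard-code the number of batches, you can derive it from the dataset's cardinality. This is just a sketch of that idea; the helper name and the 20% test fraction are illustrative, not part of the original answer:

import tensorflow as tf

def split_off_test_set(val_ds: tf.data.Dataset, test_fraction: float = 0.2):
    # Split a *batched* validation dataset into (test, validation) Datasets.
    # test_fraction is the fraction of batches handed to the test set.
    n_val_batches = val_ds.cardinality().numpy()
    n_test_batches = int(n_val_batches * test_fraction)
    test_ds = val_ds.take(n_test_batches)
    remaining_val_ds = val_ds.skip(n_test_batches)
    return test_ds, remaining_val_ds

# Usage with the val_ds returned by image_dataset_from_directory above,
# before any take()/skip() has been applied to it:
# test_dataset, val_ds = split_off_test_set(val_ds, test_fraction=0.2)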

Answered By: AloneTogether

I can’t comment yet, so I have to answer JeffreyShran’s question here about how we can be sure that take() and skip() do not end up with the same pictures in that block. Here is the check code:

dataset = tf.data.Dataset.range(10)
take = int(len(dataset)/2)

test = dataset.take(take)
print('test:', list(test.as_numpy_iterator()))
dataset = dataset.skip(take)
print('valid:', list(dataset.as_numpy_iterator()))

We get:

test: [0, 1, 2, 3, 4]
valid: [5, 6, 7, 8, 9]
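The same check can be turned into an assertion so it runs as a quick self-test. A minimal sketch, restating the snippet above:

import tensorflow as tf

dataset = tf.data.Dataset.range(10)
half = int(len(dataset) / 2)

test_items = set(dataset.take(half).as_numpy_iterator())
valid_items = set(dataset.skip(half).as_numpy_iterator())

# take() and skip() with the same count partition the dataset
assert test_items.isdisjoint(valid_items)
assert len(test_items | valid_items) == len(dataset)
print('take/skip give disjoint, complementary subsets')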

I’m a newcomer, so my apologies if this is not the appropriate place to write, but I think the consideration above needed to be verified.

Answered By: Red Applicata