Split a tf.data.Dataset into images and labels while preserving the order


I am working on a classical Cats-vs-Dogs machine learning project and have the following problem:

I have a tf.data.Dataset containing images and respective labels (0 and 1). I now want to predict the labels with a pretrained model and compare the predicted labels to the true labels. The model only accepts the images as input, so I have to split images and labels.

I imported the data as follows:

train, test = tf.keras.utils.image_dataset_from_directory(
    DATA_DIR,  # placeholder for the image directory (path not shown here)
    image_size = (IMG_SIZE, IMG_SIZE),
    crop_to_aspect_ratio = True,
    label_mode = 'binary',
    batch_size = None,
    validation_split = VAL_SPLIT,
    subset = 'both',
    seed = 1,
    shuffle = True
)

Then I tried to split the training images from the training labels:

def prepare_data_for_network(train):
    X_net = train.map(lambda image, label: image)
    X_net = X_net.batch(1)
    y_net = train.map(lambda image, label: label)
    y_net = np.squeeze(np.array([y for y in y_net]))
    return X_net, y_net
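One way to rule out ordering problems entirely is to split images and labels in a single pass over the dataset, so the pairing cannot drift even if the dataset reshuffles between iterations. A minimal sketch, using a small toy dataset (the `images`/`labels` arrays below are stand-ins, not the real Cats-vs-Dogs data):

```python
import numpy as np
import tensorflow as tf

# Toy stand-in for the Cats-vs-Dogs data: six 1x1x1 "images" with binary labels.
images = np.arange(6, dtype=np.float32).reshape(6, 1, 1, 1)
labels = np.array([0, 1, 0, 1, 1, 0], dtype=np.float32)
train = tf.data.Dataset.from_tensor_slices((images, labels))

def split_in_one_pass(ds):
    # One iteration over the dataset: each image is stored together with
    # its label, so the pairing cannot drift even if the dataset
    # reshuffles on every pass.
    X_list, y_list = [], []
    for image, label in ds:
        X_list.append(image.numpy())
        y_list.append(label.numpy())
    return np.stack(X_list), np.array(y_list)

X, y = split_in_one_pass(train)
print(X.shape)  # (6, 1, 1, 1)
```

This trades the lazy `tf.data` pipeline for in-memory arrays, so it only suits datasets that fit in RAM, but it guarantees `X[i]` and `y[i]` belong together.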

Then I predicted the labels with the model:

y_pred = model.predict(X_net)

The latter returns a NumPy array with the same shape as y_net. The problem: the order is seemingly random, e.g. y_pred[0] refers to a completely different image than y_net[0]. How can I achieve an order-preserving split?

What I tried: I set shuffle=False when importing the data, shuffled the data once afterwards, and created train and test datasets with the methods "take" and "skip". Unfortunately, this does not yield plain Datasets but "TakeDatasets", which caused further problems for me. I also tried putting everything into a NumPy array and then shuffling and splitting it; this works well at very low resolutions, but at higher resolutions it causes my kernel to die. I have now spent four days just getting the data into the right format, so I'm a bit frustrated and would really appreciate any help :/
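For what it's worth, `take` and `skip` do return usable datasets: `TakeDataset` and `SkipDataset` are subclasses of `tf.data.Dataset`, so `.map`, `.batch`, etc. still work on them. The usual pitfall with this approach is that `shuffle` reshuffles on every iteration by default, which silently changes the split between passes unless `reshuffle_each_iteration=False` is set. A small sketch on a toy range dataset:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
# reshuffle_each_iteration=False freezes the shuffled order across passes,
# so take/skip produce a stable, non-overlapping train/test split.
ds = ds.shuffle(10, seed=1, reshuffle_each_iteration=False)
train_ds = ds.take(8)
test_ds = ds.skip(8)

# TakeDataset is still a tf.data.Dataset, so further pipeline ops work:
assert isinstance(train_ds, tf.data.Dataset)
train_a = [int(x) for x in train_ds]
train_b = [int(x) for x in train_ds]
assert train_a == train_b  # same order on every iteration
```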

Asked By: Paul W.



I found the problem! The shuffled dataset always has the same order when seed = 1 is set; the issue was that I used

X_net = X_net.batch(1)

to add an "empty" batch dimension, but this seems to distort the order. The same thing can be done in an order-preserving way with

X_net = X_net.map(lambda image: tf.expand_dims(image, 0))
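As a quick sanity check on a toy dataset that this map adds the leading dimension without reordering anything:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(5)
# tf.expand_dims adds a leading "batch" dimension of size 1 per element.
expanded = ds.map(lambda x: tf.expand_dims(x, 0))
vals = [int(v.numpy()[0]) for v in expanded]
assert vals == [0, 1, 2, 3, 4]  # order preserved
```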
Answered By: Paul W.