TensorFlow Dataset Shuffle Each Epoch

Question:

The TensorFlow documentation for the Dataset class shows how to shuffle the data and how to batch it. However, it is not clear how to reshuffle the data each epoch. I tried the code below, but the second epoch returns the data in exactly the same order as the first. Does anybody know how to shuffle between epochs using a Dataset?

import tensorflow as tf

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)
for _ in range(4):
    print(sess.run(next_batch))
Asked By: Nathan


Answers:

It appears to me that you are using the same next_batch in both cases. Depending on what you really want, you may need to recreate next_batch before your second call to sess.run, as shown below. Otherwise data = data.shuffle(12) has no effect on the next_batch you created earlier in the code, because shuffle returns a new dataset and does not modify the one your iterator was built from.

import tensorflow as tf

n_epochs = 2
batch_size = 3

data = tf.contrib.data.Dataset.range(12)

data = data.repeat(n_epochs)
data = data.batch(batch_size)
next_batch = data.make_one_shot_iterator().get_next()

sess = tf.Session()
for _ in range(4):
    print(sess.run(next_batch))

print("new epoch")
data = data.shuffle(12)

# See how I recreate next_batch after the data has been shuffled
next_batch = data.make_one_shot_iterator().get_next()
for _ in range(4):
    print(sess.run(next_batch))

Please let me know if this helps. Thanks.
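A short sketch of the same point, assuming TensorFlow 2.x eager execution (the question itself uses TF 1.x). Dataset transformations such as shuffle return a new Dataset instead of changing the existing one, so an iterator created earlier keeps producing the original, unshuffled order:

```python
import tensorflow as tf

data = tf.data.Dataset.range(12)
it = iter(data)  # iterator over the ORIGINAL dataset

# shuffle() returns a new Dataset; neither `data` nor `it` is modified.
shuffled = data.shuffle(12)

original_order = [int(x.numpy()) for x in it]
print(original_order)  # still 0 through 11 in order

# To get shuffled data you have to iterate over the new dataset.
shuffled_order = [int(x.numpy()) for x in shuffled]
print(sorted(shuffled_order) == original_order)  # same elements, new order
```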

Answered By: emmanuelsa

My environment: Python 3.6, TensorFlow 1.4.

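As of TensorFlow 1.4, Dataset is part of the core API as tf.data (it was previously tf.contrib.data).

Be careful about where you call data.shuffle. In your code, the data for all epochs has already gone through repeat and batch before shuffle is applied, and next_batch was created before shuffle was even called. Here are two working examples of shuffling a dataset.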

shuffle all elements

# shuffle all elements
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=buffer_size)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[1 4 5]
[3 0 7]
[6 9 8]
[10  2 11]
epoch 2
[2 0 6]
[1 7 4]
[5 3 8]
[11  9 10]

shuffle the order of batches, not the elements within a batch

# shuffle the order of batches, not the elements within a batch
import tensorflow as tf

n_epochs = 2
batch_size = 3
buffer_size = 5

dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
dataset = dataset.repeat(n_epochs)
dataset = dataset.shuffle(buffer_size=buffer_size)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

sess = tf.Session()
print("epoch 1")
for _ in range(4):
    print(sess.run(next_batch))
print("epoch 2")
for _ in range(4):
    print(sess.run(next_batch))

OUTPUT:

epoch 1
[0 1 2]
[6 7 8]
[3 4 5]
[6 7 8]
epoch 2
[3 4 5]
[0 1 2]
[ 9 10 11]
[ 9 10 11]
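In this second example, shuffle comes after repeat, so batches from both epochs share one shuffle buffer: in the output above, [6 7 8] appears twice among the first four batches and [9 10 11] twice among the next four. If you want to shuffle the batch order while keeping every epoch complete, one option (a variation, not part of the answer above) is to shuffle before repeat. A sketch assuming TensorFlow 2.x eager execution:

```python
import tensorflow as tf

n_epochs = 2
batch_size = 3
n_batches = 12 // batch_size  # 4 batches per epoch

dataset = tf.data.Dataset.range(12)
dataset = dataset.batch(batch_size)
# Shuffle before repeat: each pass over the data is reshuffled, and
# batches from different epochs never share the shuffle buffer.
dataset = dataset.shuffle(buffer_size=n_batches)
dataset = dataset.repeat(n_epochs)

batches = [b.numpy().tolist() for b in dataset]
for epoch in range(n_epochs):
    print("epoch", epoch + 1)
    for b in batches[epoch * n_batches:(epoch + 1) * n_batches]:
        print(b)
```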
Answered By: William

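Here is a simpler solution that does not need to call repeat. Shuffle with a buffer as large as the dataset and reshuffle_each_iteration=True, then iterate over the dataset once per epoch; each new iteration produces a new order: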

import tensorflow as tf

dataset = tf.data.Dataset.range(12)
dataset = dataset.shuffle(buffer_size=dataset.cardinality(), reshuffle_each_iteration=True)
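A fuller sketch of this approach, assuming TensorFlow 2.x eager execution and a version in which Dataset.cardinality() is available. Batching after the shuffle and looping over the dataset once per epoch gives each epoch a new, complete permutation:

```python
import tensorflow as tf

n_epochs = 2
batch_size = 3

dataset = tf.data.Dataset.range(12)
# A buffer as large as the dataset gives a uniform shuffle, and the
# order is redrawn every time the dataset is iterated.
dataset = dataset.shuffle(buffer_size=dataset.cardinality(),
                          reshuffle_each_iteration=True)
dataset = dataset.batch(batch_size)

epochs = []
for epoch in range(n_epochs):
    print("epoch", epoch + 1)
    seen = []
    for batch in dataset:  # a fresh iteration, so a fresh shuffle
        print(batch.numpy())
        seen.extend(batch.numpy().tolist())
    epochs.append(seen)
```

Because shuffle sits before batch, elements are mixed across batch boundaries; the epoch boundary is simply the end of each for loop.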
Answered By: Amin