retrieving the next element from tf.data.Dataset in tensorflow 2.0 beta

Question:

Before TensorFlow 2.0 beta, to retrieve the first element from a tf.data.Dataset, we could use an iterator as shown below:

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
iterator = train_dataset.make_one_shot_iterator()
with tf.Session() as sess:
    # 1.0 will be printed.
    print (sess.run(iterator.get_next()))

In TensorFlow 2.0 beta, the one-shot iterator used above seems to be deprecated. To print all of the elements, we can use a for loop instead:

#!/usr/bin/python

import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

for data in train_dataset:
    # 1.0, 2.0, 3.0, and 4.0 will be printed.
    print (data.numpy())

However, if we only want to retrieve exactly one element from a tf.data.Dataset, how can we do that in TensorFlow 2.0 beta? Calling next(train_dataset) does not seem to be supported. This was easy with the old one-shot iterator shown above, but it is not obvious how to do it with the new for-loop approach.

Any suggestions are welcome.

Asked By: chanwcom


Answers:

You can call .take(1) on the dataset, which returns a new dataset containing at most one element, and iterate over that:

for elem in train_dataset.take(1):
  print (elem.numpy())
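To fetch an element other than the first, .take(1) can be combined with .skip(). The following is a minimal sketch, assuming TensorFlow 2.x with eager execution enabled (the default); element_at is a hypothetical helper name, not a TensorFlow API:

```python
import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])


def element_at(dataset, index):
    """Hypothetical helper: return the element at 0-based `index` as NumPy.

    skip(index) drops the first `index` elements and take(1) keeps at most
    one, so the loop body runs once, or not at all if the dataset is too short.
    """
    for elem in dataset.skip(index).take(1):
        return elem.numpy()
    raise IndexError("dataset has fewer than %d elements" % (index + 1))


print(element_at(train_dataset, 0))  # 1.0
print(element_at(train_dataset, 2))  # 3.0
```

Note that skip() still reads and discards the skipped elements, so this costs time proportional to index rather than giving random access.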
Answered By: Stewart_R

What works with eager execution enabled (default in TF 2.0) is:

elem = next(iter(train_dataset))

Explanation: Datasets implement the __iter__ method so that the for elem in dataset: loop works; it returns an iterator. The built-in Python function iter simply calls that __iter__ method, and next then returns the first element the iterator produces.
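The sketch below, assuming TensorFlow 2.x with eager execution (the default), makes that explanation concrete: next() fails on the dataset itself because it is not an iterator, a kept iterator advances on every call (much like repeated get_next() runs in TF 1.x), a fresh iter() starts over, and passing a default to next() avoids StopIteration on an exhausted iterator.

```python
import tensorflow as tf

train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])

# A Dataset is iterable but not itself an iterator (it has no __next__),
# so calling next() on it directly raises TypeError.
try:
    next(train_dataset)
    direct_next_failed = False
except TypeError:
    direct_next_failed = True

# Keep one iterator around: each next() call advances it.
it = iter(train_dataset)
first = next(it).numpy()   # 1.0
second = next(it).numpy()  # 2.0

# A new iter() call starts from the beginning again.
restarted = next(iter(train_dataset)).numpy()  # 1.0

# With a default, next() returns it instead of raising StopIteration
# when the iterator is exhausted (here: a dataset with no elements).
missing = next(iter(train_dataset.take(0)), None)  # None

print(direct_next_failed, first, second, restarted, missing)
```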

I haven’t found a solution that works without eager execution, though; there, the call currently raises RuntimeError: __iter__() is only supported inside of tf.function or when eager execution is enabled.
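When eager execution is not available, one possible workaround is to fall back to the TF 1.x API from the question, which TensorFlow 2.x still ships under tf.compat.v1. This is a sketch, assuming TensorFlow 2.x; the explicit tf.Graph() forces graph (non-eager) mode for this block even though eager execution is the default:

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Ops created inside this block are added to `graph`, not run eagerly.
    train_dataset = tf.data.Dataset.from_tensor_slices([1.0, 2.0, 3.0, 4.0])
    # The TF 1.x one-shot iterator is still available under tf.compat.v1.
    iterator = tf.compat.v1.data.make_one_shot_iterator(train_dataset)
    next_element = iterator.get_next()

with tf.compat.v1.Session(graph=graph) as sess:
    first = sess.run(next_element)   # 1.0
    # Running the same get_next() op again advances the iterator.
    second = sess.run(next_element)  # 2.0

print(first, second)
```

Alternatively, as the error message itself suggests, next(iter(train_dataset)) is also accepted inside a function decorated with tf.function.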

Answered By: Flamefire

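You can also convert the tf.data.Dataset into a NumPy iterator with as_numpy_iterator() and then call next() on it. Note that as_numpy_iterator() is not part of 2.0 beta; it was added in a later TensorFlow 2.x release.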

np_iter = train_dataset.as_numpy_iterator()
print(next(np_iter))

Documentation: https://www.tensorflow.org/api_docs/python/tf/data/Dataset
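as_numpy_iterator() also handles datasets whose elements are nested structures, converting each component to a NumPy value while keeping the structure. A sketch, assuming a TensorFlow release recent enough to provide as_numpy_iterator(); the features dataset is purely illustrative:

```python
import tensorflow as tf

# Illustrative dataset: each element is a dict {"x": float32, "y": int32}.
features = tf.data.Dataset.from_tensor_slices({"x": [1.0, 2.0], "y": [10, 20]})

np_iter = features.as_numpy_iterator()
first = next(np_iter)   # dict with x == 1.0 and y == 10 (NumPy scalars)
second = next(np_iter)  # dict with x == 2.0 and y == 20
print(first["x"], first["y"])
```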

Answered By: Ratna Sambhav