Given a dataset, find out whether it is batched

Question:

How do I implement function isBatched so that it tests if the argument dataset is batched?

import tensorflow as tf

print(tf.__version__)    

def isBatched(dataset):
    # I guess this is what @yudhiesh means
    batch = next(iter(dataset))
    return batch.shape.ndims > 0 and batch.shape[0] > 0

tensor1 = tf.range(100)
dataset = tf.data.Dataset.from_tensor_slices(tensor1)

assert isBatched(dataset.batch(10)) == True, "T1 fails"
assert isBatched(dataset.batch(10).map(lambda x: x)) == True, "T2 fails"
assert isBatched(dataset.batch(10).filter(lambda x: True).xxx.yyy.zzz) == True, "T3 fails"
assert isBatched(dataset.repeat()) == False, "T4 fails"

tensor2 = tf.random.uniform([10, 10])
dataset = tf.data.Dataset.from_tensor_slices(tensor2)
assert isBatched(dataset) == False, "T5 fails"

You don’t have to consider .batch().unbatch() cases.

I checked "Is there a way to find the batch size for a tf.data.Dataset", but it seems to require that the last call be .batch(). In my case, .batch() can appear anywhere in the call chain.

"How to get batch size back from a tensorflow dataset?" assumes the first dimension is the batch dimension. That doesn’t work if the original dataset is multi-dimensional.

Please show me code because I’m preparing my lecture for my students tomorrow.

Asked By: Gqqnbig

||

Answers:

If you have a dataset:

import tensorflow as tf

x = tf.data.Dataset.from_tensor_slices(list(range(48))) \
    .batch(4).prefetch(1)

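The last call was .prefetch(), so x._input_dataset (a private attribute) is the dataset returned by .batch(4). You can check whether it’s a BatchDataset; the exact class name may differ between TensorFlow versions: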

x._input_dataset.__class__.__name__
'BatchDataset'

It is, so it will have a _batch_size attribute:

x._input_dataset._batch_size
<tf.Tensor: shape=(), dtype=int64, numpy=4>

Batching may not be the second-to-last operation, so you may need to follow _input_dataset repeatedly until you reach the batch dataset, like so:

import tensorflow as tf

x = tf.data.Dataset.from_tensor_slices(list(range(48))) \
    .batch(4).prefetch(1).map(lambda x: x).cache()

x._input_dataset._input_dataset._input_dataset.__class__.__name__
'BatchDataset'
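
The repeated walk above could be wrapped in a helper. This is only a minimal sketch: the names find_batch_size and is_batched are hypothetical, and it assumes that batching datasets expose the private _batch_size attribute and that chained datasets expose _input_dataset, as shown above. Neither attribute is part of the public API, so this may break in other TensorFlow versions.

```python
def find_batch_size(dataset):
    """Follow the private `_input_dataset` chain.

    Returns the first `_batch_size` found, or None if no step has one.
    Both attribute names are private TensorFlow details (an assumption).
    """
    current = dataset
    while current is not None:
        if hasattr(current, "_batch_size"):
            return current._batch_size
        # Datasets without `_input_dataset` (for example, source datasets)
        # end the walk.
        current = getattr(current, "_input_dataset", None)
    return None


def is_batched(dataset):
    """True if some step in the chain exposes `_batch_size`."""
    return find_batch_size(dataset) is not None
```

With the datasets from the question, is_batched(dataset.batch(10).map(lambda x: x)) would be expected to return True and is_batched(dataset.repeat()) False, provided the installed TensorFlow version still exposes these private attributes.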
Answered By: Nicolas Gervais

So would the following solution work in all cases?

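import tensorflow as tf

def labels_from_dataset(dataset):
    if not isinstance(dataset, tf.data.Dataset):
        raise TypeError('dataset is not a tf.data.Dataset')

    # Start at the dataset itself so a trailing .batch() is also detected,
    # then follow the private _input_dataset chain.
    input_dataset = dataset
    while not hasattr(input_dataset, '_batch_size') and hasattr(input_dataset, '_input_dataset'):
        input_dataset = input_dataset._input_dataset

    # Unbatch only if a batching step was found in the chain.
    if hasattr(input_dataset, '_batch_size'):
        dataset = dataset.unbatch()

    # Assumes each element is a (features, labels) pair.
    y_labels = []
    for _, labels in dataset:
        y_labels.append(labels.numpy())

    return y_labels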
Answered By: ronaldmathies