Given a dataset, find out whether it is batched
Question:
How do I implement function isBatched
so that it tests if the argument dataset
is batched?
import tensorflow as tf
# Record which TensorFlow version the snippets below were run against,
# since the solution relies on tf.data internals that may vary by release.
print(tf.__version__)
def isBatched(dataset):
    """Return True if a batch operation appears anywhere in `dataset`'s chain.

    The naive approach -- looking at the first element's shape -- cannot
    work: a dataset built from a multi-dimensional tensor yields elements
    whose first dimension is data, not a batch (test T5 below).  Instead,
    walk the chain of wrapped datasets via the private `_input_dataset`
    attribute and look for a node carrying `_batch_size` (a BatchDataset
    or PaddedBatchDataset).

    Args:
        dataset: a tf.data.Dataset (or any object exposing the same
            `_input_dataset` / `_batch_size` attributes).

    Returns:
        bool: True if some transformation in the chain applied batching.

    NOTE(review): relies on private tf.data attributes, not a public API --
    confirm against the installed TF version.  `.batch().unbatch()` chains
    are explicitly out of scope per the question.
    """
    node = dataset
    while True:
        # A batching node is the only tf.data node that has `_batch_size`.
        if hasattr(node, '_batch_size'):
            return True
        # Source datasets (e.g. TensorSliceDataset) wrap nothing further.
        if not hasattr(node, '_input_dataset'):
            return False
        node = node._input_dataset
# Acceptance tests from the question: isBatched must detect a .batch()
# anywhere in the chain, not only as the last transformation.
tensor1 = tf.range(100)
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
assert isBatched(dataset.batch(10)) == True, "T1 fails"
assert isBatched(dataset.batch(10).map(lambda x: x)) == True, "T2 fails"
# NOTE(review): `.xxx.yyy.zzz` is a placeholder from the question meaning
# "any further chained transformations" -- it is not runnable as written.
assert isBatched(dataset.batch(10).filter(lambda x: True).xxx.yyy.zzz) == True, "T3 fails"
assert isBatched(dataset.repeat()) == False, "T4 fails"
# T5: a multi-dimensional source -- elements have shape (10,), yet the
# dataset is NOT batched; this is what breaks shape-based heuristics.
tensor2 = tf.random.uniform([10, 10])
dataset = tf.data.Dataset.from_tensor_slices(tensor2)
assert isBatched(dataset) == False, "T5 fails"
Don’t have to consider .batch().unbatch() cases.
I checked Is there a way to find the batch size for a tf.data.Dataset, which seems to require the last call being .batch(). In my case, .batch can appear anywhere during the call chain.
How to get batch size back from a tensorflow dataset? assumes the first dimension is the batch. It doesn’t work if the original dataset is multi-dimensional.
Please show me code because I’m preparing my lecture for my students tomorrow.
Answers:
If you have a dataset:
import tensorflow as tf
x = tf.data.Dataset.from_tensor_slices(list(range(48))).batch(4).prefetch(1)
You can inspect the input dataset and see if it’s a BatchDataset:
x._input_dataset.__class__.__name__
'BatchDataset'
It is, so it will have a _batch_size
attribute:
x._input_dataset._batch_size
<tf.Tensor: shape=(), dtype=int64, numpy=4>
Maybe the second last operation will not be the batching, so you may need to use _input_dataset
iteratively to find the batch dataset, like so:
import tensorflow as tf
x = tf.data.Dataset.from_tensor_slices(list(range(48))).batch(4).prefetch(1).map(lambda x: x).cache()
x._input_dataset._input_dataset._input_dataset.__class__.__name__
'BatchDataset'
So would the following solution work in all cases?
def labels_from_dataset(dataset):
    """Collect every label from a (features, labels) dataset.

    If any transformation in the dataset's chain applied batching, the
    dataset is unbatched first so labels come back one example at a time.

    Args:
        dataset: a tf.data.Dataset yielding (features, labels) pairs.

    Returns:
        list of per-example label values (as numpy scalars/arrays).

    Raises:
        TypeError: if `dataset` is not a tf.data.Dataset.

    NOTE(review): walks the private `_input_dataset` chain -- not public
    API; confirm against the installed TF version.
    """
    if not isinstance(dataset, tf.data.Dataset):
        raise TypeError('dataset is not a tf.data.Dataset')
    # Bug fix: start the walk at `dataset` itself, not `dataset._input_dataset`.
    # The original skipped the outermost node, so a dataset whose *last*
    # transformation was .batch() was never unbatched, and a source dataset
    # (which has no `_input_dataset` at all) raised AttributeError.
    node = dataset
    while not hasattr(node, '_batch_size') and hasattr(node, '_input_dataset'):
        node = node._input_dataset
    if hasattr(node, '_batch_size'):
        dataset = dataset.unbatch()
    return [labels.numpy() for _, labels in dataset]
How do I implement function isBatched
so that it tests if the argument dataset
is batched?
import tensorflow as tf
# Record which TensorFlow version the snippets below were run against,
# since the solution relies on tf.data internals that may vary by release.
print(tf.__version__)
def isBatched(dataset):
    """Return True if a batch operation appears anywhere in `dataset`'s chain.

    The naive approach -- looking at the first element's shape -- cannot
    work: a dataset built from a multi-dimensional tensor yields elements
    whose first dimension is data, not a batch (test T5 below).  Instead,
    walk the chain of wrapped datasets via the private `_input_dataset`
    attribute and look for a node carrying `_batch_size` (a BatchDataset
    or PaddedBatchDataset).

    Args:
        dataset: a tf.data.Dataset (or any object exposing the same
            `_input_dataset` / `_batch_size` attributes).

    Returns:
        bool: True if some transformation in the chain applied batching.

    NOTE(review): relies on private tf.data attributes, not a public API --
    confirm against the installed TF version.  `.batch().unbatch()` chains
    are explicitly out of scope per the question.
    """
    node = dataset
    while True:
        # A batching node is the only tf.data node that has `_batch_size`.
        if hasattr(node, '_batch_size'):
            return True
        # Source datasets (e.g. TensorSliceDataset) wrap nothing further.
        if not hasattr(node, '_input_dataset'):
            return False
        node = node._input_dataset
# Acceptance tests from the question: isBatched must detect a .batch()
# anywhere in the chain, not only as the last transformation.
tensor1 = tf.range(100)
dataset = tf.data.Dataset.from_tensor_slices(tensor1)
assert isBatched(dataset.batch(10)) == True, "T1 fails"
assert isBatched(dataset.batch(10).map(lambda x: x)) == True, "T2 fails"
# NOTE(review): `.xxx.yyy.zzz` is a placeholder from the question meaning
# "any further chained transformations" -- it is not runnable as written.
assert isBatched(dataset.batch(10).filter(lambda x: True).xxx.yyy.zzz) == True, "T3 fails"
assert isBatched(dataset.repeat()) == False, "T4 fails"
# T5: a multi-dimensional source -- elements have shape (10,), yet the
# dataset is NOT batched; this is what breaks shape-based heuristics.
tensor2 = tf.random.uniform([10, 10])
dataset = tf.data.Dataset.from_tensor_slices(tensor2)
assert isBatched(dataset) == False, "T5 fails"
Don’t have to consider .batch().unbatch() cases.
I checked Is there a way to find the batch size for a tf.data.Dataset, which seems to require the last call being .batch(). In my case, .batch can appear anywhere during the call chain.
How to get batch size back from a tensorflow dataset? assumes the first dimension is the batch. It doesn’t work if the original dataset is multi-dimensional.
Please show me code because I’m preparing my lecture for my students tomorrow.
If you have a dataset:
import tensorflow as tf
x = tf.data.Dataset.from_tensor_slices(list(range(48))).batch(4).prefetch(1)
You can inspect the input dataset and see if it’s a BatchDataset:
x._input_dataset.__class__.__name__
'BatchDataset'
It is, so it will have a _batch_size
attribute:
x._input_dataset._batch_size
<tf.Tensor: shape=(), dtype=int64, numpy=4>
Maybe the second last operation will not be the batching, so you may need to use _input_dataset
iteratively to find the batch dataset, like so:
import tensorflow as tf
x = tf.data.Dataset.from_tensor_slices(list(range(48))).batch(4).prefetch(1).map(lambda x: x).cache()
x._input_dataset._input_dataset._input_dataset.__class__.__name__
'BatchDataset'
So would the following solution work in all cases?
def labels_from_dataset(dataset):
    """Collect every label from a (features, labels) dataset.

    If any transformation in the dataset's chain applied batching, the
    dataset is unbatched first so labels come back one example at a time.

    Args:
        dataset: a tf.data.Dataset yielding (features, labels) pairs.

    Returns:
        list of per-example label values (as numpy scalars/arrays).

    Raises:
        TypeError: if `dataset` is not a tf.data.Dataset.

    NOTE(review): walks the private `_input_dataset` chain -- not public
    API; confirm against the installed TF version.
    """
    if not isinstance(dataset, tf.data.Dataset):
        raise TypeError('dataset is not a tf.data.Dataset')
    # Bug fix: start the walk at `dataset` itself, not `dataset._input_dataset`.
    # The original skipped the outermost node, so a dataset whose *last*
    # transformation was .batch() was never unbatched, and a source dataset
    # (which has no `_input_dataset` at all) raised AttributeError.
    node = dataset
    while not hasattr(node, '_batch_size') and hasattr(node, '_input_dataset'):
        node = node._input_dataset
    if hasattr(node, '_batch_size'):
        dataset = dataset.unbatch()
    return [labels.numpy() for _, labels in dataset]