Dataset.batch doesn't work as expected with a zipped dataset

Question:

I have a dataset like this:

import tensorflow as tf

a = tf.data.Dataset.range(16)
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b))
list(zipped.as_numpy_iterator())

# Output:
[(0, 16),
 (1, 17),
 (2, 18),
 (3, 19),
 (4, 20),
 (5, 21),
 (6, 22),
 (7, 23),
 (8, 24),
 (9, 25),
 (10, 26),
 (11, 27),
 (12, 28),
 (13, 29),
 (14, 30),
 (15, 31)]

When I apply batch(4) to it, the expected result is an array of batches, where each batch contains four tuples:

[[(0, 16), (1, 17), (2, 18), (3, 19)],
 [(4, 20), (5, 21), (6, 22), (7, 23)],
 [(8, 24), (9, 25), (10, 26), (11, 27)],
 [(12, 28), (13, 29), (14, 30), (15, 31)]]

But this is what I receive instead:

batched = zipped.batch(4)
list(batched.as_numpy_iterator())

# Output:
[(array([0, 1, 2, 3]), array([16, 17, 18, 19])), 
 (array([4, 5, 6, 7]), array([20, 21, 22, 23])), 
 (array([ 8,  9, 10, 11]), array([24, 25, 26, 27])), 
 (array([12, 13, 14, 15]), array([28, 29, 30, 31]))]

I'm following this tutorial; the author performs the same steps but somehow gets the correct output.


Update: according to the documentation this is the intended behavior:

The components of the resulting element will have an additional outer dimension, which will be batch_size

But this doesn't make sense to me. To my understanding, a dataset is a list of pieces of data. The shape of those pieces shouldn't matter: when we batch, we combine the elements (whatever their shape is) into batches, so batching should always insert the new dimension at the second position ((length, a, b, c) -> (length', batch_size, a, b, c)).

So my questions are: what is the purpose of batch() being implemented this way, and what is the alternative that does what I described?

Asked By: splaytreez

||

Answers:

One thing you can try doing is something like this:

import tensorflow as tf

a = tf.data.Dataset.range(16)
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b)).batch(4).map(lambda x, y: tf.transpose([x, y]))

list(zipped.as_numpy_iterator())
[array([[ 0, 16],
        [ 1, 17],
        [ 2, 18],
        [ 3, 19]]), 
 array([[ 4, 20],
        [ 5, 21],
        [ 6, 22],
        [ 7, 23]]), 
 array([[ 8, 24],
        [ 9, 25],
        [10, 26],
        [11, 27]]), 
 array([[12, 28],
        [13, 29],
        [14, 30],
        [15, 31]])]

However, the elements are still not tuples. Alternatively:

zipped = tf.data.Dataset.zip((a, b)).batch(4).map(lambda x, y: tf.unstack(tf.transpose([x, y]), num=4))
list(zipped.as_numpy_iterator())
[(array([ 0, 16]), array([ 1, 17]), array([ 2, 18]), array([ 3, 19])),
 (array([ 4, 20]), array([ 5, 21]), array([ 6, 22]), array([ 7, 23])),
 (array([ 8, 24]), array([ 9, 25]), array([10, 26]), array([11, 27])),
 (array([12, 28]), array([13, 29]), array([14, 30]), array([15, 31]))]
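
To see why the extra transpose is needed, it may help to model what batch() does with a tuple-structured dataset: it groups elements, then stacks each tuple position on its own. Below is a rough pure-Python sketch of that idea. It is not TensorFlow's implementation, and batch_per_component is a hypothetical helper name used only for illustration.

```python
from itertools import islice


def batch_per_component(elements, batch_size):
    """Hypothetical model of batching: group tuples, then stack each position separately."""
    it = iter(elements)
    while True:
        chunk = list(islice(it, batch_size))
        if not chunk:
            return
        # zip(*chunk) turns [(a0, b0), (a1, b1), ...] into (a0, a1, ...), (b0, b1, ...)
        yield tuple(list(component) for component in zip(*chunk))


pairs = list(zip(range(16), range(16, 32)))
batches = list(batch_per_component(pairs, 4))
print(batches[0])  # ([0, 1, 2, 3], [16, 17, 18, 19])

# Zipping the components of each batch again recovers the per-element tuples.
regrouped = [list(zip(*batch)) for batch in batches]
print(regrouped[0])  # [(0, 16), (1, 17), (2, 18), (3, 19)]
```

One likely reason for this design is that training APIs such as Keras Model.fit consume datasets of (inputs, targets) pairs, so keeping the components separate after batching yields one batched inputs tensor and one batched targets tensor.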
Answered By: AloneTogether

You can chain multiple batch() calls.

import tensorflow as tf

a = tf.data.Dataset.range(16)
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b))
batched = zipped.batch(1).batch(4).map(lambda x, y: tf.concat([x, y], 1))
list(batched.as_numpy_iterator())
# [array([[ 0, 16],
#         [ 1, 17],
#         [ 2, 18],
#         [ 3, 19]]),
#  array([[ 4, 20],
#         [ 5, 21],
#         [ 6, 22],
#         [ 7, 23]]),
#  array([[ 8, 24],
#         [ 9, 25],
#         [10, 26],
#         [11, 27]]),
#  array([[12, 28],
#         [13, 29],
#         [14, 30],
#         [15, 31]])]

To convert the result to a 2D list in which each item is a tuple:

result = [list(map(tuple, item)) for item in batched.as_numpy_iterator()]
print(result)
# [
#     [(0, 16), (1, 17), (2, 18), (3, 19)], 
#     [(4, 20), (5, 21), (6, 22), (7, 23)], 
#     [(8, 24), (9, 25), (10, 26), (11, 27)], 
#     [(12, 28), (13, 29), (14, 30), (15, 31)]
# ]
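
If the batched tensor itself is not needed, a shorter route may be to skip the map() step and zip the two components of each batch in plain Python. This is a sketch that assumes eager execution (the TF 2.x default) and rebuilds the same a, b and zipped as above; .tolist() is used so the tuples hold plain Python ints rather than NumPy scalars.

```python
import tensorflow as tf

a = tf.data.Dataset.range(16)
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b))

# Each batched element is a pair of NumPy arrays (x, y); pair them up row by row.
result = [
    list(zip(x.tolist(), y.tolist()))
    for x, y in zipped.batch(4).as_numpy_iterator()
]
print(result[0])  # [(0, 16), (1, 17), (2, 18), (3, 19)]
```

This leaves the dataset in the standard (x, y) layout and only reshapes the data at the point where it is turned into Python lists.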

Explanation:

>>> list(zipped.batch(1).as_numpy_iterator())
[(array([0]), array([16])),
 (array([1]), array([17])),
 (array([2]), array([18])),
 (array([3]), array([19])),
 ...
 (array([12]), array([28])),
 (array([13]), array([29])),
 (array([14]), array([30])),
 (array([15]), array([31]))]

# next, apply '.batch(4)' on top of '.batch(1)'
>>> list(zipped.batch(1).batch(4).as_numpy_iterator())
[(array([[0],
         [1],
         [2],
         [3]]),
  array([[16],
         [17],
         [18],
         [19]])),
...
 (array([[12],
         [13],
         [14],
         [15]]),
  array([[28],
         [29],
         [30],
         [31]]))]
 
# finally, concatenate the two components of each batch along axis=1
>>> list(zipped.batch(1).batch(4).map(lambda x, y: tf.concat([x, y], 1)).as_numpy_iterator())
[array([[ 0, 16],
        [ 1, 17],
        [ 2, 18],
        [ 3, 19]]),
 ...
 array([[12, 28],
        [13, 29],
        [14, 30],
        [15, 31]])]
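
A quick way to check the walkthrough above is to print the shape of the first element at each stage. This is only a sanity-check sketch; the names step1, step2 and step3 are introduced here for illustration and do not appear in the original code.

```python
import tensorflow as tf

a = tf.data.Dataset.range(16)
b = tf.data.Dataset.range(16, 32)
zipped = tf.data.Dataset.zip((a, b))

step1 = zipped.batch(1)                                  # each component: shape (1,)
step2 = step1.batch(4)                                   # each component: shape (4, 1)
step3 = step2.map(lambda x, y: tf.concat([x, y], 1))     # single tensor: shape (4, 2)

x1, y1 = next(step1.as_numpy_iterator())
x2, y2 = next(step2.as_numpy_iterator())
z3 = next(step3.as_numpy_iterator())
print(x1.shape, x2.shape, z3.shape)  # (1,) (4, 1) (4, 2)
```

The inner batch(1) is what gives each component a trailing axis of size 1, so that tf.concat along axis 1 can place the two components side by side.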
Answered By: I'mahdi