Reshape tensors of unknown shape with tf.function

Question:

Let’s say that in my function I have to deal with input tensors of shape [4,3,2,1] and [5,4,3,2,1]. I want to reshape them so that the last two dimensions are swapped, e.g., [4,3,2,1] becomes [4,3,1,2]. In eager mode this is easy, but when I wrap my function with @tf.function, the following error is thrown:

OperatorNotAllowedInGraphError: Iterating over a symbolic tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.

The code in question is as follows:

import logging

import tensorflow as tf

tensor = tf.random.uniform(shape=[4, 3, 2, 1])

@tf.function
def my_func():
    # Raises OperatorNotAllowedInGraphError when the function is traced
    reshaped = tf.reshape(tensor, shape=[*tf.shape(tensor)[:-2], tf.shape(tensor)[-1], tf.shape(tensor)[-2]])
    return reshaped

logging.info(my_func())

It looks like TensorFlow does not like the [:-2] notation, but I don't know how else to solve this problem in an elegant, readable way.
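For what it is worth, the slice itself appears to be fine inside tf.function. The error message points at iteration, and it is the `*` unpacking of the sliced shape tensor that iterates over a symbolic tensor. Below is a minimal sketch, assuming TensorFlow 2.x, that builds the target shape with tf.concat instead of unpacking it; the function name `swapped_shape_reshape` is hypothetical. Note that tf.reshape only reinterprets the existing flat element order: it produces the requested shape but does not move data between the last two axes, which is why the answer below uses tf.transpose.

```python
import tensorflow as tf


@tf.function
def swapped_shape_reshape(tensor):
    # Slicing the shape tensor is allowed in graph mode; only Python-level
    # iteration over a symbolic tensor (e.g. `*` unpacking) is disallowed.
    shape = tf.shape(tensor)
    # Concatenate 1-D slices instead of unpacking: [..., last, second-to-last]
    new_shape = tf.concat([shape[:-2], shape[-1:], shape[-2:-1]], axis=0)
    # Caution: reshape keeps the flat element order; it does not transpose.
    return tf.reshape(tensor, new_shape)


print(swapped_shape_reshape(tf.random.uniform([4, 3, 2, 1])).shape)
# (4, 3, 1, 2)
print(swapped_shape_reshape(tf.random.uniform([5, 4, 3, 2, 1])).shape)
# (5, 4, 3, 1, 2)
```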

Asked By: Felix Schön


Answers:

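Unpacking the sliced shape with * like that unfortunately does not work in graph mode, because it requires iterating over a symbolic tensor. Since you want the last two dimensions actually swapped, I think you can use tf.transpose instead: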

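import tensorflow as tf

tensor = tf.random.uniform(shape=[4, 3, 2, 1])

@tf.function
def my_func():
    rank = tf.rank(tensor)  # rank as a scalar int32 tensor, usable in graph mode
    # Offsets [0, ..., 0, 1, -1]; added to [0, 1, ..., rank - 1] they swap the
    # last two axis indices, e.g. [0, 1, 2, 3] -> [0, 1, 3, 2] for rank 4.
    offsets = tf.concat([tf.zeros((rank - 2,), dtype=tf.int32), [1, -1]], axis=-1)
    reshaped = tf.transpose(tensor, perm=tf.range(rank) + offsets)
    return reshaped

print(my_func().shape)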
# (4, 3, 1, 2)
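When the rank of the input is known at trace time (it is 4 or 5 here whenever a concrete tensor is passed in), a possibly simpler variant, sketched under that assumption, builds the permutation as a plain Python list from `tensor.shape.rank`. If the rank were truly unknown at trace time (for example with an input_signature of shape=None), `tensor.shape.rank` would be None and the tf.rank-based version above is the one to use. For comparison, the sketch also uses `tf.linalg.matrix_transpose`, a different, built-in technique that swaps the two innermost dimensions of a tensor of rank 2 or more. The function names are hypothetical, and the input is passed as an argument rather than captured from a global.

```python
import numpy as np
import tensorflow as tf


@tf.function
def swap_last_two_static(tensor):
    # The static rank is known at trace time for these inputs, so the
    # permutation can be an ordinary Python list, e.g. [0, 1, 3, 2] for rank 4.
    rank = tensor.shape.rank
    perm = list(range(rank - 2)) + [rank - 1, rank - 2]
    return tf.transpose(tensor, perm=perm)


@tf.function
def swap_last_two_builtin(tensor):
    # Transposes the innermost two dimensions; leading dimensions are untouched.
    return tf.linalg.matrix_transpose(tensor)


for shape in ([4, 3, 2, 1], [5, 4, 3, 2, 1]):
    x = tf.random.uniform(shape)
    expected = np.swapaxes(x.numpy(), -1, -2)  # NumPy reference result
    for fn in (swap_last_two_static, swap_last_two_builtin):
        np.testing.assert_array_equal(fn(x).numpy(), expected)
    print(swap_last_two_static(x).shape)
# (4, 3, 1, 2)
# (5, 4, 3, 1, 2)
```

Both functions retrace once per distinct input shape, which is expected for tf.function with concrete-shaped arguments.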
Answered By: AloneTogether