Reshape tensors of unknown shape with tf.function
Question:
Let’s say that in my function I have to deal with input tensors of shape [4,3,2,1] and [5,4,3,2,1]. I want to reshape them in such a way that the last two dimensions are swapped, e.g., to [4,3,1,2]. In eager mode this is easy, but when I wrap my function with @tf.function, the following error is thrown:
OperatorNotAllowedInGraphError: Iterating over a symbolic tf.Tensor is not allowed: AutoGraph did convert this function. This might indicate you are trying to use an unsupported feature.
The code in question is as follows:
import logging
import tensorflow as tf

tensor = tf.random.uniform(shape=[4, 3, 2, 1])
@tf.function
def my_func():
    # Raises OperatorNotAllowedInGraphError when traced by tf.function
    reshaped = tf.reshape(tensor, shape=[*tf.shape(tensor)[:-2], tf.shape(tensor)[-1], tf.shape(tensor)[-2]])
    return reshaped
logging.info(my_func())
It looks like TensorFlow does not like the [:-2] notation, but I don’t really know how else to solve this problem in an elegant and readable way.
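For reference, slicing a shape tensor such as tf.shape(tensor)[:-2] is supported in graph mode. The error comes from the `*` unpacking, which iterates over a symbolic tensor. A minimal sketch of the question's own reshape with the new shape assembled by tf.concat instead (the function name `reshape_swapped` is hypothetical). It keeps tf.reshape semantics, which reinterpret the element order rather than move data, so in general it is not the same as swapping axes.

```python
import tensorflow as tf

@tf.function
def reshape_swapped(x):
    # Slicing the shape tensor works inside tf.function; only iterating
    # over it (e.g. with `*` unpacking) is rejected.
    shape = tf.shape(x)
    # Build [..., last, second_to_last] as a single 1-D int32 tensor.
    new_shape = tf.concat([shape[:-2], shape[-1:], shape[-2:-1]], axis=0)
    # tf.reshape only reinterprets the flat element order; it does not
    # move data the way a transpose does.
    return tf.reshape(x, new_shape)

print(reshape_swapped(tf.random.uniform([4, 3, 2, 1])).shape)
print(reshape_swapped(tf.random.uniform([5, 4, 3, 2, 1])).shape)
```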
Answers:
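Unpacking a tensor with `*` like that in graph mode unfortunately does not work, because it iterates over a symbolic tensor. Since swapping the last two axes is really a transpose rather than a reshape, I think you can use tf.transpose with a permutation computed from the rank: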
import tensorflow as tf
tensor = tf.random.uniform(shape=[4, 3, 2, 1])
@tf.function
def my_func():
    rank = tf.rank(tensor)
    # Offset [0, ..., 0, 1, -1] turns range(rank) into
    # [0, ..., rank - 1, rank - 2], i.e. the last two axes are swapped.
    perm_offset = tf.concat([tf.zeros((rank - 2,), dtype=tf.int32), [1, -1]], axis=-1)
    reshaped = tf.transpose(tensor, perm=tf.range(rank) + perm_offset)
    return reshaped
print(my_func().shape)
# (4, 3, 1, 2)
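A shorter alternative, not part of the original answer: tf.linalg.matrix_transpose transposes the last two axes of any tensor of rank 2 or higher, so it covers both the [4,3,2,1] and [5,4,3,2,1] inputs without building a permutation by hand. A sketch, with the function name `swap_last_two` chosen for illustration:

```python
import tensorflow as tf

@tf.function
def swap_last_two(x):
    # Swaps the innermost two axes for any rank >= 2, so the same
    # function handles rank-4 and rank-5 inputs (each rank is traced once).
    return tf.linalg.matrix_transpose(x)

print(swap_last_two(tf.random.uniform([4, 3, 2, 1])).shape)     # (4, 3, 1, 2)
print(swap_last_two(tf.random.uniform([5, 4, 3, 2, 1])).shape)  # (5, 4, 3, 1, 2)
```

The result matches tf.transpose with the explicit permutation [0, 1, 3, 2] for a rank-4 input; the library function simply computes that permutation for you.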