How to remove a list of elements from a TensorFlow tensor
Question:
For the following tensor:
<tf.Tensor: shape=(2, 10, 6), dtype=int64, numpy=
array([[[ 3, 16, 43, 10, 7, 431],
[ 3, 2, 6, 5, 7, 2],
[ 3, 37, 5, 7, 2, 12],
[ 3, 2, 11, 5, 7, 2],
[ 3, 2, 6, 18, 14, 195],
[ 3, 2, 6, 5, 7, 195],
[ 3, 2, 6, 5, 7, 9],
[ 3, 2, 11, 7, 2, 12],
[ 3, 16, 52, 92, 177, 923],
[ 3, 9, 43, 10, 7, 9]],
[[ 3, 2, 22, 495, 230, 4],
[ 3, 2, 22, 5, 102, 122],
[ 3, 2, 22, 5, 102, 230],
[ 3, 2, 22, 5, 70, 908],
[ 3, 2, 22, 5, 70, 450],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 230],
[ 3, 2, 22, 70, 34, 470],
[ 3, 2, 22, 855, 450, 4]]], dtype=int64)>
I want to remove the last row, [3, 2, 22, 855, 450, 4],
from the tensor. I tried tf.unstack,
but it didn't work.
Answers:
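You can use tf.slice to keep only the first 9 rows along the second axis. A dense tensor must stay rectangular, so this drops the last row of every batch element, including [3, 9, 43, 10, 7, 9] from the first one, and returns a tensor of shape (2, 9, 6):
sliced_tensor = tf.slice(tensor, [0, 0, 0], [2, 9, 6])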
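A small sketch to check which rows the slice keeps. It uses a hypothetical stand-in tensor built with tf.range (values 0 to 119) that has the question's shape (2, 10, 6), not the question's actual values:

```python
import tensorflow as tf

# Hypothetical stand-in with the question's shape (2, 10, 6); values are 0..119.
x = tf.reshape(tf.range(2 * 10 * 6), (2, 10, 6))

# begin=[0, 0, 0], size=[2, 9, 6]: both batch elements, first 9 rows, all 6 columns.
sliced_tensor = tf.slice(x, [0, 0, 0], [2, 9, 6])

print(sliced_tensor.shape)  # (2, 9, 6)

# The last remaining row of each batch element is the old row 8, so the old
# row 9 is gone from BOTH elements, not only from the second one.
print(sliced_tensor[0, -1].numpy())  # [48 49 50 51 52 53]
print(sliced_tensor[1, -1].numpy())  # [108 109 110 111 112 113]
```

If only one specific row should disappear, a dense slice cannot do it; that needs a ragged result or a flattened one.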
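If you want to pass explicit strides, use tf.strided_slice instead (tf.slice takes no strides argument; its fourth parameter is name). Here end is exclusive, so the result is the same as above:
new_tensor = tf.strided_slice(tensor, [0, 0, 0], [2, 9, 6], [1, 1, 1])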
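A sketch comparing the three equivalent ways to drop the last row along the second axis: tf.strided_slice (begin/end/strides), tf.slice (begin/size), and Python indexing, which TensorFlow implements on top of strided_slice. The stand-in tensor is hypothetical, built with tf.range in the question's shape:

```python
import tensorflow as tf

# Hypothetical stand-in with the question's shape (2, 10, 6); values are 0..119.
x = tf.reshape(tf.range(2 * 10 * 6), (2, 10, 6))

# tf.strided_slice(input_, begin, end, strides): `end` is exclusive and a
# stride of 1 on every axis takes every element in the range.
strided = tf.strided_slice(x, [0, 0, 0], [2, 9, 6], [1, 1, 1])

# tf.slice(input_, begin, size): the same window expressed as a size.
sliced = tf.slice(x, [0, 0, 0], [2, 9, 6])

# Indexing needs no hard-coded sizes, so it also works for other shapes.
indexed = x[:, :-1, :]

print(strided.shape)  # (2, 9, 6)
print(bool(tf.reduce_all(tf.equal(strided, sliced))))   # True
print(bool(tf.reduce_all(tf.equal(strided, indexed))))  # True
```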
You could also use tf.ragged.boolean_mask
to exclude only the row you do not want, leaving the first batch element untouched:
import tensorflow as tf
x = tf.constant([[[ 3, 16, 43, 10, 7, 431],
[ 3, 2, 6, 5, 7, 2],
[ 3, 37, 5, 7, 2, 12],
[ 3, 2, 11, 5, 7, 2],
[ 3, 2, 6, 18, 14, 195],
[ 3, 2, 6, 5, 7, 195],
[ 3, 2, 6, 5, 7, 9],
[ 3, 2, 11, 7, 2, 12],
[ 3, 16, 52, 92, 177, 923],
[ 3, 9, 43, 10, 7, 9]],
[[ 3, 2, 22, 495, 230, 4],
[ 3, 2, 22, 5, 102, 122],
[ 3, 2, 22, 5, 102, 230],
[ 3, 2, 22, 5, 70, 908],
[ 3, 2, 22, 5, 70, 450],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 122],
[ 3, 2, 22, 5, 70, 230],
[ 3, 2, 22, 70, 34, 470],
[ 3, 2, 22, 855, 450, 4]]])
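# The row to drop; use x's dtype because tf.equal needs matching dtypes
# (the question's tensor is int64, while tf.constant defaults to int32).
remove = tf.constant([3, 2, 22, 855, 450, 4], dtype=x.dtype)
# mask[i, j] is True when row j of batch element i equals `remove` exactly.
mask = tf.reduce_all(tf.equal(x, remove), axis=-1)
# Keep rows where mask is False; the result is ragged because the two
# batch elements now have different numbers of rows (10 and 9).
x = tf.ragged.boolean_mask(x, ~mask)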
print(x)
<tf.RaggedTensor [[[3, 16, 43, 10, 7, 431],
[3, 2, 6, 5, 7, 2],
[3, 37, 5, 7, 2, 12],
[3, 2, 11, 5, 7, 2],
[3, 2, 6, 18, 14, 195],
[3, 2, 6, 5, 7, 195],
[3, 2, 6, 5, 7, 9],
[3, 2, 11, 7, 2, 12],
[3, 16, 52, 92, 177, 923],
[3, 9, 43, 10, 7, 9]] , [[3, 2, 22, 495, 230, 4],
[3, 2, 22, 5, 102, 122],
[3, 2, 22, 5, 102, 230],
[3, 2, 22, 5, 70, 908],
[3, 2, 22, 5, 70, 450],
[3, 2, 22, 5, 70, 122],
[3, 2, 22, 5, 70, 122],
[3, 2, 22, 5, 70, 230],
[3, 2, 22, 70, 34, 470]]]>
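If a RaggedTensor is inconvenient, two dense alternatives use the same mask: plain tf.boolean_mask merges the batch and row axes into one list of remaining rows, and RaggedTensor.to_tensor() pads the shorter element back to a rectangular shape with a default_value you pick. A sketch on a hypothetical reduced stand-in (3 rows per batch element instead of 10; the padding value -1 is an arbitrary choice):

```python
import tensorflow as tf

# Hypothetical reduced stand-in: shape (2, 3, 6) instead of (2, 10, 6).
x = tf.constant([[[3, 16, 43, 10, 7, 431],
                  [3, 2, 6, 5, 7, 2],
                  [3, 9, 43, 10, 7, 9]],
                 [[3, 2, 22, 495, 230, 4],
                  [3, 2, 22, 70, 34, 470],
                  [3, 2, 22, 855, 450, 4]]], dtype=tf.int64)

remove = tf.constant([3, 2, 22, 855, 450, 4], dtype=x.dtype)
# True where a whole row (last axis) equals `remove`; shape (2, 3).
mask = tf.reduce_all(tf.equal(x, remove), axis=-1)

# Option 1: dense, batch axis merged with the row axis -> shape (5, 6).
flat = tf.boolean_mask(x, ~mask)

# Option 2: keep the batch axis, pad the shorter element with -1 rows.
ragged = tf.ragged.boolean_mask(x, ~mask)
padded = ragged.to_tensor(default_value=-1)

print(flat.shape)                         # (5, 6)
print(ragged.row_lengths().numpy())       # [3 2]
print(padded.shape)                       # (2, 3, 6)
```

Option 1 loses the information about which batch element each row came from; option 2 keeps it but reintroduces a filler row, so which one fits depends on what the tensor is used for next.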