Conditional assignment of tensor values in TensorFlow

Question:

I want to replicate the following `numpy` code in `tensorflow`. For example, I want to assign a `0` to all tensor indices that previously had a value of `1`.

```python
import numpy as np

a = np.array([1, 2, 3, 1])
a[a == 1] = 0

# a should be [0, 2, 3, 0]
```

If I write similar code in `tensorflow` I get the following error.

```
TypeError: 'Tensor' object does not support item assignment
```

The condition in the square brackets should be arbitrary as in `a[a<1] = 0`.

Is there a way to realize this “conditional assignment” (for lack of a better name) in `tensorflow`?

Answer 1:

Comparison operators such as greater-than (`tf.greater`) are available in the TensorFlow API.

However, nothing is equivalent to the concise NumPy syntax when it comes to manipulating tensors directly. You have to combine individual comparison, `where` and `assign` operations to perform the same action.

The code equivalent to your NumPy example is:

```python
# TensorFlow 1.x graph-mode API; under TensorFlow 2.x it runs through tf.compat.v1
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

a = tf.Variable([1, 2, 3, 1])
start_op = tf.global_variables_initializer()
comparison = tf.equal(a, tf.constant(1))
conditional_assignment_op = a.assign(tf.where(comparison, tf.zeros_like(a), a))

with tf.Session() as session:
    # Equivalent to: a = np.array([1, 2, 3, 1])
    session.run(start_op)
    print(a.eval())
    # Equivalent to: a[a == 1] = 0
    session.run(conditional_assignment_op)
    print(a.eval())

# Output is:
# [1 2 3 1]
# [0 2 3 0]
```

The print statements are optional; they are only there to demonstrate that the code works correctly.
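For readers on TensorFlow 2.x, the same comparison / `where` / `assign` pattern can be written without a session, because eager execution is the default there. This is a minimal sketch assuming TensorFlow 2.x; it also uses a different condition (`a < 2`) to show that the comparison can be arbitrary, as the question asks:

```python
import tensorflow as tf  # assumes TensorFlow 2.x with eager execution (the default)

a = tf.Variable([1, 2, 3, 1])

# 1. comparison: a boolean mask for an arbitrary condition (same as a < 2)
mask = tf.less(a, 2)
# 2. where: take 0 where the mask is True, keep the old value elsewhere
new_values = tf.where(mask, tf.zeros_like(a), a)
# 3. assign: write the result back into the variable in place
a.assign(new_values)

print(a.numpy())  # [0 2 3 0]
```

Only `tf.Variable` supports `assign`; for a plain `tf.Tensor` you keep the tensor returned by `tf.where` instead.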

Answer 2:

I’m also just starting to use TensorFlow. Maybe someone will find my approach more intuitive:

```python
# TensorFlow 1.x graph-mode API; under TensorFlow 2.x it runs through tf.compat.v1
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

conditionVal = 1
init_a = tf.constant([1, 2, 3, 1], dtype=tf.int32, name='init_a')
a = tf.Variable(init_a, dtype=tf.int32, name='a')
target = tf.fill(a.get_shape(), conditionVal, name='target')

init = tf.global_variables_initializer()  # replaces deprecated tf.initialize_all_variables()
condition = tf.not_equal(a, target)
defaultValues = tf.zeros(a.get_shape(), dtype=a.dtype)
# tf.select was renamed to tf.where in TensorFlow 1.0
calculate = tf.where(condition, a, defaultValues)

with tf.Session() as session:
    session.run(init)
    print(session.run(calculate))  # [0 2 3 0]
```

The main difficulty is implementing “custom logic”: if you cannot express your logic in terms of TensorFlow's built-in mathematical operations, you need to write a “custom op” library for TensorFlow (more details here).
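The answer does not spell out what expressing the logic “in math terms” looks like. One reading is that a conditional assignment can be written as plain arithmetic with a 0/1 mask, with no `where` at all. This is a minimal sketch under that interpretation, assuming TensorFlow 2.x eager mode; `new_value` is a hypothetical name for the replacement value:

```python
import tensorflow as tf  # assumes TensorFlow 2.x, eager execution

a = tf.constant([1, 2, 3, 1])
new_value = 0  # hypothetical replacement value

# 1 where the condition holds, 0 elsewhere, in the same dtype as a
mask = tf.cast(tf.equal(a, 1), a.dtype)

# Keep a where mask == 0 and use new_value where mask == 1:
#   result = a * (1 - mask) + new_value * mask
result = a * (1 - mask) + new_value * mask

print(result.numpy())  # [0 2 3 0]
```

For floating-point tensors containing `inf` or `NaN`, multiplying by a zero mask yields `NaN`, so `tf.where` is usually the safer choice; the arithmetic form is mainly useful for seeing that the operation needs nothing beyond elementwise math.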

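Answer 3:

```python
# TensorFlow 2.x (eager execution)
import numpy as np
import tensorflow as tf

a = np.array([1, 2, 3, 1], dtype=np.int32)
tf.where(tf.equal(a, 1), 0, a)
```

returns

```
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([0, 2, 3, 0], dtype=int32)>
```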
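Because TensorFlow 2.x `tf.where` broadcasts a scalar against the array, the same one-liner extends to compound conditions such as NumPy's `a[(a > 1) & (a < 4)] = -1`. This is a minimal sketch assuming TensorFlow 2.x; the extra element `5` and the value `-1` are only illustrative:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x, eager execution

a = np.array([1, 2, 3, 1, 5], dtype=np.int32)

# Compound condition, equivalent to NumPy's (a > 1) & (a < 4)
condition = tf.logical_and(tf.greater(a, 1), tf.less(a, 4))

# The scalar -1 is broadcast against a; this returns a new tensor and leaves a unchanged
result = tf.where(condition, -1, a)

print(result.numpy().tolist())  # [1, -1, -1, 1, 5]
```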