Why can't my tf.tensor_scatter_nd_add do the same as torch's scatter_add_?
Question:
new_means = tf.tensor_scatter_nd_add(new_means, indices=repeat(buckets, "n -> n d", d=dim), updates=samples)
Assume new_means.shape = [3, 4], indices.shape = [4, 4] and updates.shape = [4, 4].
The code above returns this error:
Inner dimensions of output shape must match inner dimensions of updates shape. Output: [3,4] updates: [4,4].
Even when I give the two arrays the same shape, it still returns a similar error.
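For reference, tf.tensor_scatter_nd_add reads the last axis of indices as a coordinate into the output; its documented shape rule is updates.shape == indices.shape[:-1] + tensor.shape[indices.shape[-1]:]. With indices.shape = [4, 4], each row is treated as a 4-component coordinate, but new_means only has rank 2, so resizing the arrays cannot make the shapes line up. A minimal sketch, assuming (hypothetically, since the question does not show them) that buckets holds one target row per sample, with made-up values: an index of shape [N, 1] selects a whole row, which is what the PyTorch call does in this case.

```python
import tensorflow as tf

# made-up data: 3 means of dimension 4, 4 samples, one bucket (row id) per sample
new_means = tf.zeros([3, 4])
samples = tf.reshape(tf.range(16, dtype=tf.float32), [4, 4])
buckets = tf.constant([0, 2, 0, 1])

# index depth 1: each index picks a whole row of new_means
indices = buckets[:, tf.newaxis]  # shape [4, 1]
# shape rule: updates.shape == indices.shape[:-1] + new_means.shape[1:] == [4, 4]
result = tf.tensor_scatter_nd_add(new_means, indices, samples)
print(result.numpy())
# [[ 8. 10. 12. 14.]
#  [12. 13. 14. 15.]
#  [ 4.  5.  6.  7.]]
```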
But it works well with PyTorch's scatter_add_.
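For comparison, torch's tensor.scatter_add_(0, index, src) on 2-D tensors performs tensor[index[i][j]][j] += src[i][j] for every (i, j). The exact PyTorch call is not shown in the question, so this pure-Python reference assumes dim=0 and uses made-up values; the function name scatter_add_dim0 is hypothetical.

```python
def scatter_add_dim0(tensor, index, src):
    """Pure-Python reference for torch's tensor.scatter_add_(0, index, src) on 2-D lists."""
    out = [row[:] for row in tensor]  # copy; the torch method works in place instead
    for i, index_row in enumerate(index):
        for j, target_row in enumerate(index_row):
            out[target_row][j] += src[i][j]
    return out


means = [[0] * 4 for _ in range(3)]
samples = [[4 * i + j for j in range(4)] for i in range(4)]
index = [[b] * 4 for b in [0, 2, 0, 1]]  # like repeat(buckets, "n -> n d", d=4)
print(scatter_add_dim0(means, index, samples))
# [[8, 10, 12, 14], [12, 13, 14, 15], [4, 5, 6, 7]]
```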
I don't understand why, or how to achieve the same thing in TensorFlow as in PyTorch.
Can you help me please?
I tried to read the official documentation and found some confusing requirements in it. How can I get the same effect as scatter_add_?
Answers:
Inspired by a similar question and its resolution:
how to change torch.scatter_add to tensorflow function
I solved my own problem by transforming the original data into one dimension (to satisfy the shape requirements of the TensorFlow function).
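The flattening works because, in a row-major [R, C] tensor, element (r, c) sits at flat position r * C + c, so tensor[indices[i][j]][j] becomes position indices[i][j] * C + j. A pure-Python sketch of that arithmetic, with a hypothetical function name and made-up values (here the index varies per column, which torch also allows):

```python
def scatter_add_dim0_flat(tensor, index, src):
    """Scatter-add along dim 0 by working on a flattened copy of `tensor`."""
    n_rows, n_cols = len(tensor), len(tensor[0])
    flat = [v for row in tensor for v in row]  # row-major flatten
    for i, index_row in enumerate(index):
        for j, r in enumerate(index_row):
            flat[r * n_cols + j] += src[i][j]  # (r, j) -> r * n_cols + j
    # reshape back to n_rows x n_cols
    return [flat[k * n_cols:(k + 1) * n_cols] for k in range(n_rows)]


means = [[0] * 4 for _ in range(3)]
samples = [[4 * i + j for j in range(4)] for i in range(4)]
index = [[0, 1, 2, 0], [2, 2, 1, 0], [0, 0, 0, 0], [1, 1, 1, 1]]
print(scatter_add_dim0_flat(means, index, samples))
# [[8, 9, 10, 21], [12, 14, 20, 15], [4, 5, 2, 0]]
```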
import tensorflow as tf
from einops import repeat

def scatter_add(tensor, indices, updates):
    """
    Emulate torch's tensor.scatter_add_(0, indices, updates) for a 2-D tensor.
    Because of the shape requirements of tf.tensor_scatter_nd_add, we first
    reshape everything to one dimension.
    """
    n_cols = tensor.shape[1]
    # column id j for every entry indices[i][j]
    indices_add = tf.range(0, indices.shape[-1], dtype=indices.dtype)
    indices_add = repeat(indices_add, "n -> d n", d=indices.shape[0])
    # row-major flat position of tensor[indices[i][j]][j]
    indices = indices * n_cols + indices_add
    flat = tf.reshape(tensor, shape=[-1])
    indices = tf.reshape(indices, shape=[-1, 1])
    updates = tf.reshape(updates, shape=[-1])
    scatter = tf.tensor_scatter_nd_add(flat, indices, updates)
    # restore the original shape (tf.squeeze would also drop any size-1 axis)
    return tf.reshape(scatter, tf.shape(tensor))
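Because indices in the question is built with repeat(buckets, "n -> n d", d=dim), every column of a row carries the same bucket, so the scatter reduces to a per-row segment sum. A sketch, assuming TensorFlow 2.x eager mode and made-up data, that checks an einops-free variant of the function above (hypothetically named scatter_add_dim0, using tf.broadcast_to and tf.tile instead of repeat, and assuming indices and updates have the same shape) against tf.math.unsorted_segment_sum:

```python
import tensorflow as tf


def scatter_add_dim0(tensor, indices, updates):
    """Sketch of torch's tensor.scatter_add_(0, indices, updates) for 2-D inputs."""
    indices = tf.cast(indices, tf.int32)
    n_cols = tf.shape(tensor)[1]
    # column id j for every entry indices[i][j]
    col_ids = tf.broadcast_to(tf.range(tf.shape(indices)[1]), tf.shape(indices))
    # row-major flat position, as a [N * D, 1] index for a 1-D scatter
    flat_idx = tf.reshape(indices * n_cols + col_ids, [-1, 1])
    flat = tf.tensor_scatter_nd_add(
        tf.reshape(tensor, [-1]), flat_idx, tf.reshape(updates, [-1]))
    return tf.reshape(flat, tf.shape(tensor))


new_means = tf.zeros([3, 4])
samples = tf.reshape(tf.range(16, dtype=tf.float32), [4, 4])
buckets = tf.constant([0, 2, 0, 1])
indices = tf.tile(buckets[:, tf.newaxis], [1, 4])  # same as repeat(buckets, "n -> n d", d=4)

a = scatter_add_dim0(new_means, indices, samples)
# when every column shares one bucket, a segment sum over rows gives the same result
b = new_means + tf.math.unsorted_segment_sum(samples, buckets, num_segments=3)
print(a.numpy())
print(b.numpy())
```

The flattened version is the general one (it handles a different row index per column, as torch does); the segment sum is simpler when each sample goes to a single bucket.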