Why can't my tf.tensor_scatter_nd_add do the same as torch scatter_add_?

Question:

new_means = tf.tensor_scatter_nd_add(new_means, indices=repeat(buckets, "n -> n d", d=dim), updates=samples)

Assume new_means.shape=[3, 4], indices.shape=[4, 4] and updates.shape=[4, 4].
The above code returns the error:

Inner dimensions of output shape must match inner dimensions of updates shape. Output: [3,4] updates: [4,4].

Even when I give the two arrays the same shape, it still returns a similar error.

But it works well with PyTorch's scatter_add_.
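
For reference, here is roughly what the working PyTorch version looks like (a minimal sketch; the concrete values of buckets and samples are only for illustration):

import torch
from einops import repeat

new_means = torch.zeros(3, 4)
buckets = torch.tensor([0, 2, 0, 1])   # target row for each sample
samples = torch.ones(4, 4)
dim = samples.shape[-1]

# scatter_add_ along dim 0: new_means[buckets[i], j] += samples[i, j]
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)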

I don't know why, or how to achieve the same result in TensorFlow as in PyTorch.

Can you help me please?

I tried to read the official documentation and found some of its shape requirements confusing. How can I get the same effect as scatter_add_?

Asked By: Joe Jane


Answers:

Inspired by a similar question and its resolution:
how to change torch.scatter_add to tensorflow function

I solved my own question by flattening the original data to one dimension (to satisfy the shape requirements of the TensorFlow function):


import tensorflow as tf
from einops import repeat


def scatter_add(tensor, indices, updates):
    """
    Emulates torch.Tensor.scatter_add_ along dim 0. tf.tensor_scatter_nd_add has
    stricter shape requirements, so we first flatten everything to one dimension.
    """
    original_tensor = tensor

    # Column offsets 0..C-1, one row of offsets per row of `indices`.
    indices_add = tf.range(0, indices.shape[-1])
    indices_add = repeat(indices_add, "n -> d n", d=indices.shape[0])

    # Turn (row, column) pairs into flat positions: row * C + column.
    # Note: indices.shape[-1] must equal tensor.shape[-1], as in scatter_add along dim 0.
    indices = indices * indices.shape[-1]
    indices += indices_add

    # Flatten target, indices and updates into the 1-D form the TF op accepts.
    tensor = tf.reshape(tensor, shape=[-1])
    indices = tf.reshape(indices, shape=[-1, 1])
    updates = tf.reshape(updates, shape=[-1])

    # Duplicate flat indices are accumulated, matching scatter_add semantics.
    scatter = tf.tensor_scatter_nd_add(tensor, indices, updates)

    # Restore the original 2-D shape.
    scatter = tf.reshape(scatter, shape=[original_tensor.shape[0], original_tensor.shape[1], -1])
    scatter = tf.squeeze(scatter)

    return scatter
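
For example, with the shapes from the question (a minimal usage sketch; the values of buckets and samples are made up for illustration):

new_means = tf.zeros([3, 4])
buckets = tf.constant([0, 2, 0, 1])    # target row for each sample
samples = tf.ones([4, 4])
dim = samples.shape[-1]

result = scatter_add(new_means, repeat(buckets, "n -> n d", d=dim), samples)
# result has shape [3, 4]; row 0 accumulates samples 0 and 2, row 1 gets sample 3, row 2 gets sample 1
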
Answered By: Joe Jane