Gathering entries in a matrix based on a matrix of column indices (tensorflow/numpy)

Question:

A little example to demonstrate what I need

I have a question about gathering in tensorflow. Let’s say I have a tensor of values (that I care about for some reason):

test1 = tf.round(5*tf.random.uniform(shape=(2,3)))

which gives me this output:

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 1., 2.],
       [4., 5., 0.]], dtype=float32)>

and I also have a tensor of column indices that I want to pick out on every row:

test_ind = tf.constant([[0,1,0,0,1],
                        [0,1,1,1,0]], dtype=tf.int64)

I want to gather this so that from the first row (row 0) I pick out the items in columns 0, 1, 0, 0, 1, and from the second row the items in columns 0, 1, 1, 1, 0.

So the output for this example should be:

<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[1., 1., 1., 1., 1.],
       [4., 5., 5., 5., 4.]], dtype=float32)>
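In other words, the rule is out[i, j] = test1[i, test_ind[i, j]]. As a reference point (not part of the original question), NumPy's np.take_along_axis applies exactly this rule along axis=1. The sketch below hard-codes the example values shown above instead of drawing them at random, so the result can be compared with the expected output.

```python
import numpy as np

# Fixed copy of the example values shown above (the question draws them at random)
values = np.array([[1., 1., 2.],
                   [4., 5., 0.]], dtype=np.float32)

# Column indices to pick from each row
col_idx = np.array([[0, 1, 0, 0, 1],
                    [0, 1, 1, 1, 0]], dtype=np.int64)

# out[i, j] = values[i, col_idx[i, j]]
out = np.take_along_axis(values, col_idx, axis=1)
print(out)
# [[1. 1. 1. 1. 1.]
#  [4. 5. 5. 5. 4.]]
```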

My attempt

I figured out a general way to do this: the following function, gather_matrix_indices(), takes a tensor of values and a tensor of indices and does exactly what I specified above.

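import tensorflow as tf

def gather_matrix_indices(input_arr, index_arr):
    # input_arr: [rows, cols] values; index_arr: [rows, k] column indices
    row, _ = input_arr.shape

    li = []

    # Gather the requested columns from each row separately, then join the rows back together
    for i in range(row):
        li.append(tf.expand_dims(tf.gather(params=input_arr[i], indices=index_arr[i]), axis=0))

    return tf.concat(li, axis=0)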

My Question

I’m just wondering: is there a way to do this using ONLY built-in TensorFlow or NumPy operations, without a Python loop? The only solution I could come up with is writing my own function that iterates over every row and gathers the requested columns in that row. I have not had runtime issues yet, but I would much rather use built-in TensorFlow or NumPy methods when possible. I’ve also tried tf.gather, but I don’t know if this particular case is possible with any combination of tf.gather and tf.gather_nd. If anyone has a suggestion, I would greatly appreciate it.
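For reference, one loop-free option in TensorFlow is tf.gather with its batch_dims argument: with batch_dims=1, row i of the indices tensor is applied to row i of the params tensor, which is the same per-row gather the loop in gather_matrix_indices() performs. This is a sketch using the question's example values hard-coded rather than random; it is one way to do it, not necessarily the only one.

```python
import tensorflow as tf

# Example values from the question, fixed instead of random
test1 = tf.constant([[1., 1., 2.],
                     [4., 5., 0.]], dtype=tf.float32)
test_ind = tf.constant([[0, 1, 0, 0, 1],
                        [0, 1, 1, 1, 0]], dtype=tf.int64)

# batch_dims=1: row i of test_ind indexes (along axis 1) into row i of test1
res = tf.gather(test1, test_ind, batch_dims=1)
print(res)  # shape (2, 5)
```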

Edit (08/18/22)

In PyTorch, calling torch.gather() with dim=1 does EXACTLY what I wanted in this question. So if you’re familiar with both libraries and really need this functionality, torch.gather() provides it out of the box.
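A minimal sketch of that PyTorch call, using the question's example values hard-coded rather than random. With dim=1, torch.gather(input, dim, index) returns a tensor shaped like index where out[i][j] = input[i][index[i][j]]; the index tensor must be int64, which is the default dtype for integer literals in torch.tensor.

```python
import torch

# Example values from the question
test1 = torch.tensor([[1., 1., 2.],
                      [4., 5., 0.]])
test_ind = torch.tensor([[0, 1, 0, 0, 1],
                         [0, 1, 1, 1, 0]])  # int64 by default

# dim=1: out[i][j] = test1[i][test_ind[i][j]]
res = torch.gather(test1, 1, test_ind)
print(res)  # shape (2, 5)
```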

Asked By: AndrewJaeyoung


Answers:

You can use gather_nd() for this. It can look a bit tricky to set up, so let me explain it in terms of shapes.

We have test1 -> [2, 3] and test_ind_col_ind -> [2, 5]. test_ind_col_ind contains only column indices, but gather_nd() also needs the row indices. To use gather_nd() with a [2, 3] tensor, we need to create a test_ind tensor of shape [2, 5, 2]. The innermost dimension of this new test_ind holds the individual indices you want to look up in test1; here it has size 2, in the format (<row index>, <col index>). In other words, looking at the shape of test_ind:

[ 2 , 5 , 2 ]
    |     |
    V     |
  (2,5)   |       <- The size of the final tensor   
          V
         (2,)     <- The full index to a scalar in your input tensor
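
import tensorflow as tf

test1 = tf.round(5*tf.random.uniform(shape=(2,3)))
print(test1)

# Column indices with a trailing axis added: shape [2, 5, 1]
test_ind_col_ind = tf.constant([[0,1,0,0,1],
                                [0,1,1,1,0]], dtype=tf.int64)[:, :, tf.newaxis]

# Row index of every entry, repeated across the 5 columns: shape [2, 5, 1]
test_ind_row_ind = tf.repeat(tf.range(2, dtype=tf.int64)[:, tf.newaxis, tf.newaxis], 5, axis=1)

# Pair them up as (<row index>, <col index>): shape [2, 5, 2]
test_ind = tf.concat([test_ind_row_ind, test_ind_col_ind], axis=-1)

res = tf.gather_nd(indices=test_ind, params=test1)
print(res)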
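To make the shape reasoning concrete, the same index construction can be sketched in NumPy, where the intermediate shapes are easy to inspect. This is an illustrative mirror rather than TensorFlow's gather_nd itself: indexing test1 with the row plane and the column plane of the [2, 5, 2] index array plays the role of gather_nd. The example values are fixed instead of random.

```python
import numpy as np

test1 = np.array([[1., 1., 2.],
                  [4., 5., 0.]])
col_ind = np.array([[0, 1, 0, 0, 1],
                    [0, 1, 1, 1, 0]])

rows, k = col_ind.shape
# Row index of every output entry: [[0, 0, 0, 0, 0], [1, 1, 1, 1, 1]], shape [2, 5]
row_ind = np.repeat(np.arange(rows)[:, np.newaxis], k, axis=1)

# Stack into (<row index>, <col index>) pairs: shape [2, 5, 2]
test_ind = np.stack([row_ind, col_ind], axis=-1)
print(test_ind.shape)  # (2, 5, 2)

# Look up each (row, col) pair, as gather_nd would: shape [2, 5]
res = test1[test_ind[..., 0], test_ind[..., 1]]
print(res)
```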

Answered By: thushv89