jax segment_sum along array dimension

Question:

I am fairly new to jax and have the following problem:
I need to compute reductions (sum/min/max, maybe more complex ones later) over the rows of an array, grouped by an index array. For this I found the jax.ops.segment_sum function. It works well for a single array, but how can I generalize this approach to a batch of arrays? E.g.:

import jax
import jax.numpy as jnp
import numpy as np

indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.array([
    np.arange(9).reshape((3, 3)),
    np.arange(9).reshape((3, 3))
])
# The following works for one array but not multiple
jax.ops.segment_sum(
    data=batch_of_matrixes[0],
    segment_ids=indexes[0],
    num_segments=2)
# How can I get this to work with the full dataset along the 0 dimension?
# Intended outcome (shape (2, 2, 3)):
# [[[ 3  4  5]
#   [ 6  8 10]]
#
#  [[ 3  5  7]
#   [ 6  7  8]]]

If there is a more general way to do this than the jax.ops.segment_* family, please let me know as well. Thanks in advance for any help and suggestions!

Asked By: Simon P.


Answers:

JAX’s vmap transformation is designed for exactly this kind of situation. In your case, you can use it like this:

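import jax

# jax.vmap maps f over axis 0 of both arguments, so each matrix
# is reduced using its own row of segment ids.
@jax.vmap
def f(data, index):
  return jax.ops.segment_sum(data, index, num_segments=2)

print(f(batch_of_matrixes, indexes))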
# [[[ 3  4  5]
#   [ 6  8 10]]

#  [[ 3  5  7]
#   [ 6  7  8]]]
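Since the question also mentions min and max, here is a self-contained sketch of the same vmap idea applied to several reductions. `batched_segment_op` is a hypothetical helper, not part of JAX. It assumes `num_segments` is a static Python int, because segment_sum and its siblings need it to know the output shape.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np

indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.array([
    np.arange(9).reshape((3, 3)),
    np.arange(9).reshape((3, 3)),
])

def batched_segment_op(op, data, segment_ids, num_segments):
    """Hypothetical helper: apply a jax.ops.segment_* reduction per batch entry.

    `op` must follow segment_sum's convention: op(data, segment_ids, num_segments=...).
    """
    # Bind the static num_segments so vmap only sees the two array arguments.
    single = functools.partial(op, num_segments=num_segments)
    # vmap maps over axis 0 of both data and segment_ids.
    return jax.vmap(single)(data, segment_ids)

sums = batched_segment_op(jax.ops.segment_sum, batch_of_matrixes, indexes, 2)
maxes = batched_segment_op(jax.ops.segment_max, batch_of_matrixes, indexes, 2)
mins = batched_segment_op(jax.ops.segment_min, batch_of_matrixes, indexes, 2)
print(sums)
```

The same pattern should work for any reduction with segment_sum's `(data, segment_ids, num_segments)` signature, for example jax.ops.segment_prod.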

For some more discussion of this, see JAX 101: Automatic Vectorization.
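If you would rather avoid vmap, one alternative is to shift each batch entry's segment ids into a disjoint range and make a single flat segment_sum call. This is a different technique from the answer above. `flat_batched_segment_sum` is a hypothetical helper, and the sketch assumes every id already lies in `[0, num_segments)`.

```python
import jax
import jax.numpy as jnp
import numpy as np

indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.array([
    np.arange(9).reshape((3, 3)),
    np.arange(9).reshape((3, 3)),
])

def flat_batched_segment_sum(data, segment_ids, num_segments):
    """Hypothetical helper: batched segment_sum via one flat call.

    data: (batch, n, ...); segment_ids: (batch, n) ints in [0, num_segments).
    """
    batch, n = segment_ids.shape
    # Entry b uses ids b * num_segments .. b * num_segments + num_segments - 1,
    # so rows from different batch entries never land in the same segment.
    offsets = jnp.arange(batch)[:, None] * num_segments
    flat_ids = (segment_ids + offsets).reshape(batch * n)
    flat_data = data.reshape((batch * n,) + data.shape[2:])
    out = jax.ops.segment_sum(flat_data, flat_ids, num_segments=batch * num_segments)
    # Split the flat segments back into (batch, num_segments, ...).
    return out.reshape((batch, num_segments) + data.shape[2:])

result = flat_batched_segment_sum(batch_of_matrixes, indexes, 2)
print(result)
```

This trick only covers reductions that can be expressed through segment ids. vmap is usually the simpler and more general choice.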

Answered By: jakevdp