jax segment_sum along array dimension
Question:
I am fairly new to JAX and have the following problem:
I need to compute reductions (sum/min/max, maybe more complex functions later) across an array, grouped by an index. To solve this problem I found the jax.ops.segment_sum function. This works great for one array, but how can I generalize this approach to a batch of arrays? E.g.:
import jax
import jax.numpy as jnp
import numpy as np

indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.array([
    np.arange(9).reshape((3, 3)),
    np.arange(9).reshape((3, 3)),
])

# The following works for one array but not multiple
jax.ops.segment_sum(
    data=batch_of_matrixes[0],
    segment_ids=indexes[0],
    num_segments=2)
# How can I get this to work with the full dataset along the 0 dimension?
# Intended Outcome:
[
[
[ 3 4 5],
[ 6 8 10]
],
[
[3 5 7],
[6 7 8]
]
]
If there is a more general way to do this than the jax.ops.segment_* family, please also let me know. Thanks in advance for help and suggestions!
Answers:
JAX’s vmap transformation is designed for exactly this kind of situation. In your case, you can use it like this:
@jax.vmap
def f(data, index):
    return jax.ops.segment_sum(data, index, num_segments=2)

print(f(batch_of_matrixes, indexes))
# [[[ 3 4 5]
# [ 6 8 10]]
# [[ 3 5 7]
# [ 6 7 8]]]
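To make it clearer what vmap is doing here, the following sketch compares it with an explicit Python loop over the batch dimension. The helper name segment_sum_one is only illustrative and does not come from JAX; the inputs are rebuilt with jnp.arange so the snippet runs on its own.

```python
import jax
import jax.numpy as jnp

# Same inputs as in the question, rebuilt here so the snippet is self-contained.
indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.stack([jnp.arange(9).reshape(3, 3)] * 2)

def segment_sum_one(data, index):
    # Works on a single (3, 3) matrix and a single (3,) index vector.
    return jax.ops.segment_sum(data, index, num_segments=2)

# Explicit loop over the leading (batch) axis, then stack the results.
looped = jnp.stack(
    [segment_sum_one(d, i) for d, i in zip(batch_of_matrixes, indexes)]
)

# vmap maps the same function over axis 0 of both arguments.
vmapped = jax.vmap(segment_sum_one)(batch_of_matrixes, indexes)

print(vmapped.shape)
print(bool(jnp.array_equal(looped, vmapped)))
```

Both versions should produce the same (2, 2, 3) result; the vmap version avoids the Python-level loop.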
For some more discussion of this, see JAX 101: Automatic Vectorization.
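Since the question also mentions min and max, the same pattern should carry over to jax.ops.segment_min and jax.ops.segment_max. The sketch below is one possible way to wrap this; batched_segment is a hypothetical helper written for this example, not part of JAX, and it assumes num_segments is a plain Python int.

```python
import jax
import jax.numpy as jnp

indexes = jnp.array([[1, 0, 1], [0, 0, 1]])
batch_of_matrixes = jnp.stack([jnp.arange(9).reshape(3, 3)] * 2)

def batched_segment(op, data, segment_ids, num_segments):
    """Apply a jax.ops.segment_* function independently to each batch entry.

    `op` is assumed to accept (data, segment_ids, num_segments=...), as
    jax.ops.segment_sum, segment_min and segment_max do.
    """
    def per_example(d, ids):
        # num_segments is closed over as a static Python int.
        return op(d, ids, num_segments=num_segments)
    return jax.vmap(per_example)(data, segment_ids)

seg_sum = batched_segment(jax.ops.segment_sum, batch_of_matrixes, indexes, 2)
seg_max = batched_segment(jax.ops.segment_max, batch_of_matrixes, indexes, 2)
seg_min = batched_segment(jax.ops.segment_min, batch_of_matrixes, indexes, 2)

print(seg_sum)
print(seg_max)
print(seg_min)
```

Passing the reduction as an argument keeps the batching logic in one place, so swapping sum for min or max does not require a new vmap wrapper each time.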