Rewriting a for loop with jax.lax.scan

Question:

I’m having trouble understanding the JAX documentation. Can somebody give me a hint on how to rewrite simple code like this with jax.lax.scan?

import numpy

numbers = numpy.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
evenNumbers = 0
for row in numbers:
    for n in row:
        if n % 2 == 0:
            evenNumbers += 1
Asked By: pepazdepa


Answers:

Assuming a solution should demonstrate the concepts rather than optimize the example shown, the function passed to jax.lax.scan must match the expected signature, f(carry, x) -> (new_carry, y), and any condition that depends on a traced value has to be replaced with jax.lax.cond. The code below is the closest to the original I could think of, but please be aware that I’m anything but a jaxpert.

import jax
import jax.numpy as jnp

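def f(carry, row):
    # scan calls f(carry, x) once per row and expects (new_carry, y) back.
    even = 0
    # row has a static length, so this Python loop is unrolled at trace time.
    for n in row:
        # n is a traced value, so a Python `if` would fail; use lax.cond instead.
        even += jax.lax.cond(n % 2 == 0, lambda: 1, lambda: 0)

    # carry accumulates the running total; `even` is stacked into the per-row output.
    return carry + even, even

numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
# Returns (final carry, stacked per-row outputs).
jax.lax.scan(f, 0, numbers)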

Output

(DeviceArray(2, dtype=int32, weak_type=True),
 DeviceArray([1, 0, 1], dtype=int32, weak_type=True))
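As a variation that goes beyond the answer above (a sketch, not the answer author's code), the per-row work can be vectorized: `row % 2 == 0` yields a boolean array for the whole row, and summing it counts the even entries, so neither the inner Python loop nor `jax.lax.cond` is needed. The carry starts as `jnp.int32(0)` and the sum uses `dtype=jnp.int32` so the carry keeps the same type on every iteration, which `jax.lax.scan` requires. The name `count_even_row` is hypothetical.

```python
import jax
import jax.numpy as jnp


def count_even_row(carry, row):
    # Elementwise comparison on the traced row gives a boolean array;
    # summing it counts the even entries without any Python `if`.
    even = jnp.sum(row % 2 == 0, dtype=jnp.int32)
    # Same (new_carry, per_step_output) contract as in the answer.
    return carry + even, even


numbers = jnp.array([[3.0, 14.0], [15.0, -7.0], [16.0, -11.0]])
total, per_row = jax.lax.scan(count_even_row, jnp.int32(0), numbers)
print(total)    # 2
print(per_row)  # [1 0 1]
```

For this particular example, scan is not strictly necessary: `jnp.sum(numbers % 2 == 0)` counts all even entries in one call. The scan form pays off when each step genuinely depends on the carry from the previous step.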
Answered By: Michael Szczesny