Issue with jax.lax.scan

Question:

I am supposed to use jax.lax.scan instead of a for loop with 100 iterations at line 22. I am supposed to update S and append it to S_list. I am unsure how to fix the jax.lax.scan call. The error that keeps popping up is that the required argument xs is missing. When I pass a value for xs, it says that my length argument doesn't line up with the axis sizes. Here is my code. Can you help me?

Asked By: Bob Zofaul


Answers:

You're not calling scan with the correct signature; see the jax.lax.scan docs for details. They make clear, for example, that your step function must accept two arguments (the carry and the current element of xs) and return two values (the new carry and the per-step output that scan stacks for you).
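For reference, the behavior of jax.lax.scan can be modeled in plain Python. The sketch below is a simplified model (single-array xs or None only); `scan_reference`, `step` and `count` are hypothetical names used just for illustration. It shows the carry being threaded through each call and the per-step outputs being stacked along a new leading axis. When xs is None, length must be given; when xs is an array, length can be omitted, and if it is passed it has to equal xs.shape[0], which is what the "length doesn't line up with the axis sizes" error is about.

```python
import jax
import jax.numpy as jnp
import numpy as np


def scan_reference(f, init, xs=None, length=None):
    """Simplified pure-Python model of jax.lax.scan (single-array xs only)."""
    if xs is None:
        xs = [None] * length  # no per-step input, just `length` iterations
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)  # step returns (new_carry, y)
        ys.append(y)
    return carry, jnp.stack(ys)  # per-step outputs stacked on axis 0


def step(total, x):
    # Running sum: the carry is the total so far, y is the updated total.
    new_total = total + x
    return new_total, new_total


xs = jnp.arange(5.0)
final_ref, ys_ref = scan_reference(step, jnp.array(0.0), xs)
final, ys = jax.lax.scan(step, jnp.array(0.0), xs)  # length inferred from xs
print(ys)  # running sums: 0, 1, 3, 6, 10


def count(c, _):
    # xs=None: the second argument is always None, so it is ignored.
    return c + 1, c


final_n, ys_n = jax.lax.scan(count, jnp.array(0), xs=None, length=3)
```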

From your code, it looks like you intend to do something like this:

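import jax
import jax.numpy as jnp

# Assumed example values; the question's code defines its own r, σ, dt and m.
r, σ, dt, m = 0.05, 0.2, 0.01, 100

@jax.jit
def simulate():
  key = jax.random.PRNGKey(0)
  def step(carry, _):
    S, key = carry
    key, subkey = jax.random.split(key)  # reusing one key would give the same dZ every step
    dZ = jax.random.normal(subkey, shape=(S.size,)) * jnp.sqrt(dt)
    dS = r * S * dt + σ * S * dZ
    return (S + dS, key), S  # (new carry, value collected for this step)
  S0 = jnp.ones(20000)
  _, S_array = jax.lax.scan(step, (S0, key), xs=None, length=m)
  return S_array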

In particular, the docs show that scan collects the second value returned by the step function at every iteration and stacks the results into one array. The S_list.append(...) and S_array = jnp.stack(S_list) steps are therefore handled by scan itself, so you don't have to do them after calling it.
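To check that scan really does the appending and stacking for you, here is a sketch comparing an explicit loop (with S_list.append and jnp.stack) against an equivalent scan. The parameter values, the smaller array size n, and the choice to pre-split one PRNG key per step and pass those keys as xs are assumptions for this illustration, not taken from the question. Because xs is an array here, length is inferred from its leading axis and doesn't need to be passed.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Assumed example parameters (the question's own values are not shown).
r, σ, dt, m, n = 0.05, 0.2, 0.01, 100, 1000
keys = jax.random.split(jax.random.PRNGKey(0), m)  # one key per step
S0 = jnp.ones(n)


def update(S, k):
    # One Euler step of dS = r S dt + σ S dZ, with dZ ~ N(0, dt).
    dZ = jax.random.normal(k, shape=(S.size,)) * jnp.sqrt(dt)
    return S + r * S * dt + σ * S * dZ


# Explicit loop: record the current S, then update it.
S = S0
S_list = []
for k in keys:
    S_list.append(S)
    S = update(S, k)
S_loop = jnp.stack(S_list)


# scan: the second return value of step is stacked automatically.
def step(S, k):
    return update(S, k), S


S_last, S_scan = jax.lax.scan(step, S0, xs=keys)  # length inferred as m
```

Both produce an array of shape (m, n) whose first row is S0; passing keys through xs is simply an alternative to the xs=None, length=m form, and both run m iterations.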

Hope that helps!

Answered By: jakevdp