How to vmap over cho_solve and cho_factor?

Question:

The following error appears because of the last line of code below:

jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected…

The error is raised from inside a call to bool().

It looks like it is caused by the lower value returned by cho_factor, which _cho_solve (note the underscore) requires to be static.

I’m new to JAX, so I was hoping that vmap-ing cho_factor into cho_solve would just work. What have I done wrong here?

import jax

key = jax.random.PRNGKey(0)
k_y = jax.random.normal(key, (100, 10, 10))
y = jax.random.normal(key, (100, 10, 1))

matmul = jax.vmap(jax.numpy.matmul)
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve)

k_y = matmul(k_y, jax.numpy.transpose(k_y, (0, 2, 1)))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)
Asked By: logan
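A quick way to see where the traced lower comes from is to inspect what the vmapped cho_factor actually returns. The sketch below is a minimal reproduction under stated assumptions: it reuses the question's shapes, and the + 10 * I shift is an addition (not in the question) that keeps every matrix comfortably positive definite.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (100, 10, 10))
# Batch of symmetric positive-definite matrices; the +10*I shift is an assumption
k_y = a @ jnp.swapaxes(a, -1, -2) + 10.0 * jnp.eye(10)

# As in the question: every output of the vmapped function uses out_axes=0
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
chol, lower = cho_factor(k_y)

# The unbatched Python bool `lower` comes back broadcast along the batch axis
print(chol.shape, lower.shape, lower.dtype)
```

Presumably, once this array is fed to a cho_solve vmapped with default in_axes, each element becomes a tracer, and the place where lower is used as a Python bool is what raises the ConcretizationTypeError.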


Answers:

So I didn’t manage to get cho_factor and cho_solve working, but worked around it using cholesky and solve_triangular:

  cholesky = jax.vmap(jax.scipy.linalg.cholesky, in_axes=(0, None))
  solve_tri = jax.vmap(jax.scipy.linalg.solve_triangular, in_axes=(0, 0, None, None))

  L = cholesky(k_y, True)
  result2 = solve_tri(L, solve_tri(L, y, 0, True), 1, True)
Answered By: logan
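The cholesky/solve_triangular workaround above relies on the factorization K = L L^T: solving K x = y splits into a forward solve L z = y (trans=0) followed by a back solve L^T x = z (trans=1). A minimal self-contained check, assuming random inputs with the question's shapes plus a diagonal shift (an assumption added so each K is well conditioned in float32), might look like this:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (100, 10, 10))
# Symmetric positive-definite batch; the +10*I shift is an assumption
k_y = a @ jnp.swapaxes(a, -1, -2) + 10.0 * jnp.eye(10)
y = jax.random.normal(jax.random.PRNGKey(1), (100, 10, 1))

# Map over the matrices; keep the Python flags (lower, trans) unmapped
cholesky = jax.vmap(jax.scipy.linalg.cholesky, in_axes=(0, None))
solve_tri = jax.vmap(jax.scipy.linalg.solve_triangular, in_axes=(0, 0, None, None))

L = cholesky(k_y, True)   # lower-triangular factor, K = L L^T
z = solve_tri(L, y, 0, True)  # forward substitution: L z = y
x = solve_tri(L, z, 1, True)  # back substitution: L^T x = z

# The residual K x - y should be close to zero for every batch element
print(jnp.allclose(k_y @ x, y, atol=1e-3))
```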

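The issue is that lower is a static Python bool that should not be mapped over. With the default out_axes=0, the vmapped cho_factor broadcasts lower into a batched boolean array, so the vmapped cho_solve receives a traced value where it expects a concrete bool. If you set out_axes and in_axes so that lower uses axis None (unmapped), the vmap should work: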

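# out_axes=(0, None): batch the factor, but return lower unmapped
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor, out_axes=(0, None))
# in_axes=((0, None), 0): map over the factor and y, but not over lower
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 0))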
Answered By: jakevdp
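For completeness, here is a sketch of jakevdp's fix run end to end, next to an alternative that vmaps a per-matrix function so lower never crosses a vmap boundary at all. The input data and the + 10 * I shift are assumptions for illustration, and solve_one is a hypothetical helper name, not part of JAX.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (100, 10, 10))
# Symmetric positive-definite batch; the +10*I shift is an assumption
k_y = a @ jnp.swapaxes(a, -1, -2) + 10.0 * jnp.eye(10)
y = jax.random.normal(jax.random.PRNGKey(1), (100, 10, 1))

# Option 1: jakevdp's fix, keeping lower unmapped on the way out and back in
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor, out_axes=(0, None))
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 0))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)

# Option 2 (hypothetical helper): factor and solve one matrix, then vmap that;
# lower stays a plain Python bool inside the function and is never traced
def solve_one(k, b):
    return jax.scipy.linalg.cho_solve(jax.scipy.linalg.cho_factor(k), b)

result2 = jax.vmap(solve_one)(k_y, y)

print(jnp.allclose(result, result2, atol=1e-5))
```

The second form is often easier to reason about, since only arrays cross the vmap boundary and no in_axes/out_axes bookkeeping is needed.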