How to vmap over cho_solve and cho_factor?
Question:
The following error is raised by the last line of the code below:

`jax.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected…`

The problem arose in the `bool` function. It appears to be caused by the `lower` value returned by `cho_factor`, which `_cho_solve` (note the underscore) requires to be static.

I'm new to JAX, so I was hoping that vmap-ing `cho_factor` into `cho_solve` would just work. What have I done wrong here?
```python
import jax
key = jax.random.PRNGKey(0)
k_y = jax.random.normal(key, (100, 10, 10))
y = jax.random.normal(key, (100, 10, 1))
matmul = jax.vmap(jax.numpy.matmul)
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor)
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve)
k_y = matmul(k_y, jax.numpy.transpose(k_y, (0, 2, 1)))
chol, lower = cho_factor(k_y)
result = cho_solve((chol, lower), y)  # raises jax.errors.ConcretizationTypeError
```
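A small diagnostic sketch shows where the tracer comes from (assuming a freshly generated batch of matrices with the question's shapes). With the default `out_axes=0`, `vmap` stacks every output along the batch axis, including the plain Python bool `lower` that `cho_factor` returns. The second output therefore comes back as a boolean array with one entry per matrix, and when it is fed to a `cho_solve` vmapped with the default `in_axes=0`, each call sees `lower` as a traced value instead of a concrete bool.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor

# Same shapes as the question: a batch of 100 symmetric positive definite 10x10 matrices
a = jax.random.normal(jax.random.PRNGKey(0), (100, 10, 10))
k_y = a @ jnp.swapaxes(a, -1, -2)  # batched A @ A^T

# Default out_axes=0: the unbatched bool `lower` is broadcast along axis 0 as well
chol, lower = jax.vmap(cho_factor)(k_y)
print(type(lower), lower.shape, lower.dtype)
```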
Answers:
So I didn't manage to get `cho_factor` and `cho_solve` working, but I worked around it using `cholesky` and `solve_triangular`:
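```python
# `lower` and `trans` are plain Python flags, so they are passed unmapped (in_axes=None)
cholesky = jax.vmap(jax.scipy.linalg.cholesky, in_axes=(0, None))
solve_tri = jax.vmap(jax.scipy.linalg.solve_triangular, in_axes=(0, 0, None, None))
L = cholesky(k_y, True)  # lower=True, so k_y = L @ L^T
# Solve L z = y (trans=0), then L^T x = z (trans=1)
result2 = solve_tri(L, solve_tri(L, y, 0, True), 1, True)
```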
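The workaround relies on the factorization k_y = L Lᵀ: solving L z = y by forward substitution and then Lᵀ x = z by back substitution gives x = k_y⁻¹ y. The self-contained sketch below checks that numerically. It assumes independently generated `k_y` and `y` (separate keys, unlike the question, which reuses one key) and enables float64 only so the residual check is not dominated by float32 round-off.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cholesky, solve_triangular

# Assumption: float64 so the residual check below is tight
jax.config.update("jax_enable_x64", True)

key_k, key_y = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_k, (100, 10, 10))
k_y = a @ jnp.swapaxes(a, -1, -2)  # batch of SPD matrices
y = jax.random.normal(key_y, (100, 10, 1))

# `lower` and `trans` stay unmapped Python values (in_axes=None)
batched_cholesky = jax.vmap(cholesky, in_axes=(0, None))
batched_solve_tri = jax.vmap(solve_triangular, in_axes=(0, 0, None, None))

L = batched_cholesky(k_y, True)             # k_y = L @ L^T
z = batched_solve_tri(L, y, 0, True)        # forward substitution: L z = y
result2 = batched_solve_tri(L, z, 1, True)  # back substitution: L^T x = z

# Each slice should satisfy k_y[i] @ result2[i] = y[i]
residual = jnp.max(jnp.abs(k_y @ result2 - y))
print(result2.shape, residual)
```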
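The issue is that in each case, `lower` is a static boolean flag that should not be mapped over. By default, `vmap` stacks every output along axis 0, so the vmapped `cho_factor` returns `lower` as a length-100 boolean array, and the vmapped `cho_solve` then sees it as a tracer. If you specify `out_axes` and `in_axes` so that `lower` gets the axis `None` (i.e. it is not mapped), the `vmap` should work: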
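```python
# Batch the factor, but return `lower` unmapped so it stays a concrete value
cho_factor = jax.vmap(jax.scipy.linalg.cho_factor, out_axes=(0, None))
# Map over the factor and `y`, but not over `lower`
cho_solve = jax.vmap(jax.scipy.linalg.cho_solve, in_axes=((0, None), 0))
```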
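Putting these two `vmap` calls together gives the following self-contained sketch. It assumes independently generated `k_y` and `y` with the question's shapes, and enables float64 purely so the check that each slice satisfies `k_y[i] @ result[i] ≈ y[i]` can use a tight tolerance.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# Assumption: float64 so the residual check below is tight
jax.config.update("jax_enable_x64", True)

key_k, key_y = jax.random.split(jax.random.PRNGKey(0))
a = jax.random.normal(key_k, (100, 10, 10))
k_y = a @ jnp.swapaxes(a, -1, -2)  # batch of SPD matrices
y = jax.random.normal(key_y, (100, 10, 1))

# Batch the factor but leave `lower` unmapped, so it is never turned into a tracer
batched_cho_factor = jax.vmap(cho_factor, out_axes=(0, None))
# Map over the factor and y, but not over `lower`
batched_cho_solve = jax.vmap(cho_solve, in_axes=((0, None), 0))

chol, lower = batched_cho_factor(k_y)
result = batched_cho_solve((chol, lower), y)

residual = jnp.max(jnp.abs(k_y @ result - y))
print(lower, result.shape, residual)
```

If you want the lower-triangular factor instead, pass `lower` in the same unmapped way, e.g. `jax.vmap(cho_factor, in_axes=(0, None), out_axes=(0, None))(k_y, True)`.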