Error when trying to jit the computation of the Jacobian in JAX: "ValueError: Non-hashable static arguments are not supported"

Question:

This question is similar to the question here, but I can't see what I should change in my own code.

I have a function

from jax import jit, jacobian

def elbo(variational_parameters, eps, a, b):
    ...
    return theta, _

elbo = jit(elbo, static_argnames=["a", "b"])

where variational_parameters is a vector (one-dimensional array) of length P, eps is a two-dimensional array of dimensions K by N, and a, b are fixed values.

The elbo function has been successfully vmapped over the rows of eps and jitted by passing a and b to static_argnames. It returns theta, a two-dimensional array of dimensions K by P.

I want to take the Jacobian of the output theta with respect to variational_parameters through the elbo function. The first value returned by

jacobian(elbo, argnums=0, has_aux=True)(variational_parameters, eps, a, b)

gives me a three-dimensional array of dimensions K by P by N. This is what I want. As soon as I try to jit this function

jit(jacobian(elbo, argnums=0, has_aux=True))(variational_parameters, eps, a, b)

I get the error

ValueError: Non-hashable static arguments are not supported, which can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function elbo is non-hashable.

Any help would be greatly appreciated; thanks!

Asked By: hasco641


Answers:

Any parameters you pass to a JIT-compiled function will no longer be static, unless you explicitly mark them as such. So this line:

jit(jacobian(elbo, argnums=0, has_aux=True))(variational_parameters, eps, a, b)

makes variational_parameters, eps, a, and b non-static. Within the transformed function, these non-static values are then passed to this function:

elbo = jit(elbo, static_argnames=["a", "b"])

which means that you are passing non-static (traced) values where elbo expects static arguments, and that raises the error.
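A small, self-contained way to see this behavior: the toy function scale below (hypothetical, not from the question) records whether its argument a arrives as a plain Python float. Under a bare jax.jit, a is replaced by a tracer; marking it with static_argnums keeps the concrete value. This mirrors what happens to a and b in the question when the outer jit does not mark them as static.

```python
import jax
import jax.numpy as jnp

# One entry per trace: True if `a` was a concrete Python float.
seen = []


def scale(x, a):
    seen.append(isinstance(a, float))
    return x * a


x = jnp.arange(3.0)

# No static marking: jit replaces every argument, including a, with a tracer.
jax.jit(scale)(x, 2.0)

# static_argnums=1: a is passed through as the concrete Python float 2.0.
jax.jit(scale, static_argnums=1)(x, 2.0)

print(seen)  # [False, True]
```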

To fix this, you should mark the static parameters as static any time they enter a jit-compiled function. In your case it might look something like this:

jit(jacobian(elbo, argnums=0, has_aux=True),
    static_argnums=(2, 3))(variational_parameters, eps, a, b)
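
Below is a runnable sketch of this fix, again using a hypothetical toy body for elbo in place of the elided ELBO computation (the sizes and formula are assumptions). It also shows one alternative that is not part of the answer: binding a and b with functools.partial before transforming, so they never become arguments of the outer jit at all. Both approaches should produce the same Jacobian.

```python
import functools

import jax
import jax.numpy as jnp

K, N, P = 4, 5, 3  # hypothetical sizes for illustration


def elbo(variational_parameters, eps, a, b):
    # Hypothetical toy body standing in for the elided ELBO computation.
    def per_row(eps_row):
        return a * jnp.tanh(variational_parameters) + b * jnp.sum(eps_row)

    theta = jax.vmap(per_row)(eps)  # shape (K, P)
    return theta, None  # None stands in for the auxiliary output


elbo = jax.jit(elbo, static_argnames=["a", "b"])

variational_parameters = jnp.linspace(-1.0, 1.0, P)
eps = jnp.ones((K, N))

# Fix from the answer: keep a and b (positions 2 and 3) static in the outer
# jit too, so they reach the inner jitted elbo as concrete Python floats.
jac_fn = jax.jit(jax.jacobian(elbo, argnums=0, has_aux=True), static_argnums=(2, 3))
jac, _ = jac_fn(variational_parameters, eps, 2.0, 0.5)

# Alternative: bind a and b up front with functools.partial, so the outer
# jit only ever sees the array arguments.
elbo_ab = functools.partial(elbo, a=2.0, b=0.5)
jac2, _ = jax.jit(jax.jacobian(elbo_ab, argnums=0, has_aux=True))(variational_parameters, eps)

print(jac.shape)  # (4, 3, 3)
```

For this toy function, the Jacobian of the (K, P) output with respect to the length-P vector variational_parameters has shape (K, P, P). With either approach, each distinct pair of (a, b) values is part of the compilation cache key, so changing them triggers a fresh trace and compile.
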
Answered By: jakevdp