Issues with non-hashable static arguments when forming a VJP

Question:

I want to compute a vector-Jacobian product (VJP).

The function func takes four arguments, the final two of which are static:

def func(variational_params, e, A, B):
    ...
    return model_params, dlogp, ...

The function jits perfectly fine via

func_jitted = jit(func, static_argnums=(2, 3))
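Static arguments must be hashable, because jit uses their values as part of its compilation cache key. A minimal sketch with a hypothetical function scaled_power (not from the question), whose second argument drives Python control flow and therefore has to be static:

```python
import jax
import jax.numpy as jnp

# Hypothetical example: `p` is used in a Python `if`, so under jit it must be
# a concrete (static) value rather than a traced array.
def scaled_power(x, p):
    if p == 1.0:
        return x
    return x ** p

f = jax.jit(scaled_power, static_argnums=(1,))
x = jnp.arange(4.0)
result = f(x, 2.0)  # 2.0 is a hashable Python float, so it can key the cache

# An unhashable value (here a list) is rejected as a static argument.
g = jax.jit(lambda x, powers: sum(x ** q for q in powers), static_argnums=(1,))
try:
    g(x, [1.0, 2.0])
    rejected = False
except ValueError as err:
    rejected = True
    print(err)
```

Using a tuple such as (1.0, 2.0) instead of the list would make the second call work, since tuples of floats are hashable.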

The primals are the variational_params, and the cotangents are dlogp (the second output of the function).

Calculating the vector-Jacobian product naively (by forming the full Jacobian) works fine:

jacobian_func = jacobian(func_jitted, argnums=0, has_aux=True)
jacobian_jitted = jit(jacobian_func, static_argnums=(2, 3))
jac, func_output = jacobian_jitted(variational_params, e, A, B)
naive_vjp = func_output.T @ jac 
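In symbols (notation introduced here for illustration), with x = variational_params, f(x) = model_params with m entries, and v = dlogp, the naive computation above evaluates

```latex
\texttt{naive\_vjp} = v^\top J_f(x), \qquad
J_f(x) = \frac{\partial f}{\partial x} \in \mathbb{R}^{m \times n}
```

Building J_f with jacobian (reverse mode) effectively evaluates one VJP per output row, so its cost grows with m, whereas vjp evaluates v^T J_f(x) directly without materializing J_f.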

When I try to form the VJP efficiently via

f_eval, vjp_function, aux_output = vjp(func_jitted, variational_params, e, A, B, has_aux=True)

I get the following error:

ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 2) of type <class 'jax.interpreters.ad.JVPTracer'> for function func is non-hashable.

I am a little confused, since func jits perfectly fine. There is no option for passing static_argnums to vjp, so I am not sure what this error means.

Asked By: hasco641


Answers:

For higher-level transformation APIs like jit, JAX generally provides a mechanism like static_argnums or argnums to allow specification of static vs. dynamic variables.

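For lower-level transformation routines like jvp and vjp, these mechanisms are not provided. vjp treats every positional argument after the function as a primal and traces it, so A and B reach func_jitted as tracers (the JVPTracer in the error message), and tracers are not hashable. You can still accomplish the same thing by passing a partially applied function, so that only the arguments you want to differentiate are passed as primals. For example: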

from functools import partial

f_eval, vjp_function, aux_output = vjp(partial(func_jitted, A=A, B=B), variational_params, e, has_aux=True)
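A self-contained sketch comparing the naive Jacobian-based product with the partial-based vjp. The body of func is hypothetical (the question does not show it): a reparameterized Gaussian where A is a prior variance and B is the dimension used for slicing, returning (model_params, dlogp) so that has_aux=True applies.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical stand-in for `func`. A (Python float) and B (Python int) are
# static; B is used as a slice bound, which requires a concrete value.
def func(variational_params, e, A, B):
    mu, log_sigma = variational_params[:B], variational_params[B:]
    model_params = mu + jnp.exp(log_sigma) * e
    # Gradient of log N(model_params | 0, A * I), returned as auxiliary output.
    dlogp = -model_params / A
    return model_params, dlogp

func_jitted = jax.jit(func, static_argnums=(2, 3))

A, B = 2.0, 3
variational_params = jnp.array([1.0, 1.0, 1.0, 0.0, 0.0, 0.0])
e = jnp.array([0.5, -1.0, 2.0])

# Naive approach: materialize the full Jacobian, then contract it with dlogp.
jacobian_jitted = jax.jit(jax.jacobian(func_jitted, argnums=0, has_aux=True),
                          static_argnums=(2, 3))
jac, dlogp = jacobian_jitted(variational_params, e, A, B)
naive_vjp = dlogp @ jac

# Efficient approach: bind A and B with partial so vjp only traces the arrays.
# jit maps static_argnums=(2, 3) to the parameter names A and B, so passing
# them by keyword keeps them static.
f_eval, vjp_function, aux_output = jax.vjp(
    partial(func_jitted, A=A, B=B), variational_params, e, has_aux=True)
grad_params, grad_e = vjp_function(aux_output)

print(naive_vjp)
print(grad_params)
```

vjp_function returns one cotangent per remaining primal, so it also returns a cotangent for e; if e should not be differentiated, it could be bound with partial in the same way.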

This is effectively how transformation parameters like argnums and static_argnums are implemented under the hood.
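The equivalence can be checked against jax.grad's argnums with a hypothetical least-squares loss: binding the other arguments with partial (or with a closure when the argument sits in the middle of the signature) gives the same gradient. This is an illustration of the idea, not JAX's internal code.

```python
from functools import partial

import jax
import jax.numpy as jnp

# Hypothetical loss, used only to compare argnums with manual partial application.
def loss(w, x, y):
    return jnp.sum((x @ w - y) ** 2)

w = jnp.array([1.0, -2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([0.5, 0.0, 1.0])

# argnums=0: bind the trailing arguments by keyword with partial.
g_w_builtin = jax.grad(loss, argnums=0)(w, x, y)
g_w_partial = jax.grad(partial(loss, x=x, y=y))(w)

# argnums=1: an argument in the middle is easiest to isolate with a closure.
g_x_builtin = jax.grad(loss, argnums=1)(w, x, y)
g_x_closure = jax.grad(lambda x_: loss(w, x_, y))(x)
```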

Answered By: jakevdp