How to un-JIT-compile a function which is called by a JIT-compiled function

Question:

I have a script which performs some calculations on some given arrays. These calculations are performed thousands of times, so naturally I want to use JAX’s JIT decorator to speed up these calculations. I have several functions which are called from some "master function," and I want to JIT-compile the master function. However, there is one function I don’t want to be JIT-compiled because it can’t be made JIT compatible (or, at least, I don’t know how to make it so). Below is an example:

import jax
from functools import partial
import numpy as np


def function(params, X):
    # create an array of zeros with same length as x (not X)
    # set values to -1 if corresponding value of x (not X) is between specified limits
    # otherwise set values to zero
    
    values = jax.numpy.zeros(len(x))
    
    for i in range(len(x)):
        if x[i] < params[1] and x[i] > params[0]:
            values = values.at[i].set(-1)
            
    X.val = values
    return X


# @jax.jit
def master_function(params):
    # vmap previous function onto x
    
    partial_function = partial(function, params)
    return jax.vmap(partial_function)(x)


# define some variables
params = [4, 6]
x = np.linspace(0, 10, 100)

# run master function
new_x = master_function(params)

# print new_x array
print(new_x)

In this simple example, I have some array x. I want to then create a copy of that array, called new_x, where each value is either a 0 or a -1. If a value in x is between some bounds (specified by params), its value in new_x should be -1, and zero otherwise. When I don’t JIT-compile master_function(), this script works perfectly. However, when I JIT-compile master_function, and, by extension, function, I get the following error:

Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `bool` function. 
The error occurred while tracing the function master_function at temp.py:28 for jit. This concrete value was not available in Python because it depends on the value of the argument 'params'.

I understand that this error is caused by the way JIT-compilation works, so I want to un-JIT-compile function() while still JIT-compiling master_function if possible.

Asked By: PositronJon

Answers:

You cannot normally [1] call an un-jitted function from within a JIT-compiled function. In your case, the best solution looks to be rewriting your function so that it is JIT-compatible. Assuming jax.numpy is imported as jnp, you can replace your for-loop with this:

values = jnp.where((x < params[1]) & (x > params[0]), -1.0, 0.0)

Side note: it looks like you’re doing in-place modifications of the val attribute of a batch tracer, which is not a supported operation and will probably have unexpected consequences. I’d suggest writing your code using standard operations, but the intent of your code is not clear to me, so I’m not sure what change to suggest.
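
For context, here is a minimal sketch of how the whole script might look when built around the jnp.where call above. It rests on two assumptions of mine that the question does not confirm: that x can be passed to master_function as an explicit argument, and that returning a plain array (instead of assigning to X.val) is acceptable.

```python
import jax
import jax.numpy as jnp


@jax.jit
def master_function(params, x):
    # Assumption: x is passed in explicitly rather than read from a global.
    lower = params[0]
    upper = params[1]
    # Element-wise select: -1 where lower < x < upper, 0 elsewhere.
    # This replaces the Python for-loop and if-statement, so nothing
    # needs a concrete boolean while the function is being traced.
    return jnp.where((x > lower) & (x < upper), -1.0, 0.0)


# Same inputs as in the question
params = jnp.array([4.0, 6.0])
x = jnp.linspace(0.0, 10.0, 100)

new_x = master_function(params, x)
print(new_x)
```

Because jnp.where already operates on the whole array at once, this version should not need vmap or partial at all.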


[1] This is actually possible using jax.pure_callback, but it is probably not what you want because it comes with performance penalties.
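
As a rough illustration of that footnote, the sketch below calls an ordinary Python/NumPy function from jitted code through jax.pure_callback. The name host_function and the choice to pass x as an argument are hypothetical, chosen only for this example. The callback runs on the host each time it is called, which is where the performance penalty comes from.

```python
import jax
import jax.numpy as jnp
import numpy as np


def host_function(params, x):
    # Hypothetical non-jittable function: plain Python control flow on
    # concrete NumPy values. It runs on the host and is never traced.
    params = np.asarray(params)
    x = np.asarray(x)
    values = np.zeros(x.shape, dtype=np.float32)
    for i in range(x.shape[0]):
        if params[0] < x[i] < params[1]:
            values[i] = -1.0
    return values


@jax.jit
def master_function(params, x):
    # JAX cannot inspect the callback, so the output shape and dtype
    # have to be declared up front.
    result_shape = jax.ShapeDtypeStruct(x.shape, jnp.float32)
    return jax.pure_callback(host_function, result_shape, params, x)


params = jnp.array([4.0, 6.0])
x = jnp.linspace(0.0, 10.0, 100)

new_x = master_function(params, x)
print(new_x)
```

The callback must be a pure function (same inputs, same outputs, no side effects), since JAX is free to cache or skip calls to it.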

Answered By: jakevdp