Caching Behavior in JAX

Question:

I have a function f that takes a boolean static argument flag and performs some computation based on its value. Below is a rough outline of this function.

import jax
from functools import partial

@partial(jax.jit, static_argnames=['flag'])
def f(x, flag):
    # Preprocessing
    if flag:
        ...  # computation for flag=True
    else:
        ...  # computation for flag=False
    # Postprocessing

I would expect every call of f with a different value of flag to trigger a recompilation. However, because flag is a boolean and can take at most two values, it would be preferable for JAX to cache the compiled version of f for each possible value of flag and avoid recompiling.

In short, I would like JAX to compile f only twice when running the following piece of code:

flag = True
for i in range(100):
    f(x, flag)
    flag = not flag

Is there a way to tell JAX not to throw away old compiled versions of f each time it is called with a new value of flag? More generally, does JAX implement any caching mechanisms for such scenarios? (For instance, if flag is an integer but we know beforehand that it will only ever take k distinct values, we would like to save the compiled version of f for each of these k values.)

I know that I can use jax.lax.cond or jax.lax.switch to control the flow inside f and treat flag as a regular argument rather than a static one. But this would make the code much more bloated (and harder to read), as there are several places in the body of f where I access flag. It would be much cleaner to declare flag as a static argument and then control the caching behavior of jax.jit to avoid recompilations.
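For comparison, here is a minimal sketch of the jax.lax.cond alternative mentioned above, assuming the two branches are simply jnp.sin and jnp.cos (hypothetical stand-ins for the real computation). Because flag is now a traced (dynamic) argument, a single compilation serves both values; the trade-off is that every use of flag inside the body must go through lax.cond or lax.switch instead of a plain Python if.

```python
import jax
import jax.numpy as jnp

@jax.jit
def g(x, flag):
    # flag is traced here, so a Python `if` on it is not allowed;
    # lax.cond picks a branch at run time instead.
    # jnp.sin / jnp.cos are hypothetical stand-ins for the real branches.
    return jax.lax.cond(flag, jnp.sin, jnp.cos, x)

x = jnp.arange(10.0)
for flag in [True, False, True]:
    g(x, flag)

# Both flag values share one compiled executable.
print(g._cache_size())
```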

Asked By: Saeed Hedayatian


Answers:

If I understand your question correctly, JAX by default already behaves the way you want. Each JIT-compiled function has an LRU cache of compilations, keyed on the shapes and dtypes of the dynamic arguments and the hashes of the static arguments. You can inspect the size of this cache using the private _cache_size method of the jitted function. For example:

import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames=['flag'])
def f(x, flag):
    if flag:
        return jnp.sin(x)
    else:
        return jnp.cos(x)

print(f._cache_size())
# 0

x = jnp.arange(10)
f(x, True)
print(f._cache_size())
# 1

f(x, False)
print(f._cache_size())
# 2

# Subsequent calls with the same flag value hit the cache:
flag = True
for i in range(100):
    f(x, flag)
    flag = not flag
print(f._cache_size())
# 2

Since the shape and dtype of x haven't changed, we get one cache entry for each value of flag, and the cached compilations are reused in subsequent calls.
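The same mechanism covers the integer case raised in the question. A minimal sketch, assuming a hypothetical static integer argument mode that only ever takes the k = 3 values 0, 1 and 2 (the branch bodies are placeholders): each distinct value produces one cache entry, and later calls with an already-seen value reuse it. Because the cache is LRU, a very large number of distinct static values could eventually evict older entries.

```python
import jax
import jax.numpy as jnp
from functools import partial

@partial(jax.jit, static_argnames=['mode'])
def h(x, mode):
    # mode is a static Python int, so ordinary Python control flow works.
    # The three branches are placeholder computations.
    if mode == 0:
        return jnp.sin(x)
    elif mode == 1:
        return jnp.cos(x)
    else:
        return jnp.tanh(x)

x = jnp.arange(10.0)
for i in range(99):
    h(x, i % 3)  # cycles through the k = 3 values 0, 1, 2

print(h._cache_size())  # one entry per distinct value of mode
```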

Note, however, that if you change the shape or dtype of the dynamic argument, you get new cache entries:

x = jnp.arange(100)
for i in range(100):
    f(x, flag)
    flag = not flag
print(f._cache_size())
# 4

This is necessary because, in general, a function may change its behavior based on these quantities: shapes and dtypes are fixed at trace time, so each new combination has to be traced and compiled separately.
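As an illustration, here is a hypothetical function, summarize, whose traced logic depends on the input shape. Python branching on x.shape is allowed under jax.jit because shapes are known at trace time, so arrays of length 10 and 100 follow different code paths and each needs its own compilation.

```python
import jax
import jax.numpy as jnp

@jax.jit
def summarize(x):
    # x.shape is a concrete tuple during tracing, so this Python `if` is
    # resolved at trace time; each input shape gets its own compiled code.
    if x.shape[0] > 50:
        return x.sum()
    return x.mean()

print(summarize(jnp.arange(10)))   # mean of 0..9, i.e. 4.5
print(summarize(jnp.arange(100)))  # sum of 0..99, i.e. 4950
print(summarize._cache_size())     # two shapes, two entries
```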

Answered By: jakevdp