Returning a distribution object from a jittable function

Question:

I want to create a jittable function that outputs a distrax distribution object. For instance:

import distrax
import jax
import jax.numpy as jnp

def f(x):
    dist = distrax.Categorical(logits=jnp.sin(x))
    return dist

jit_f = jax.jit(f)
a = jnp.array([1, 2, 3])
dist = jit_f(a)

Currently this code gives me the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 628, in cache_miss
    out = tree_unflatten(out_pytree_def, out_flat)
  File "F:\jax_env\lib\site-packages\jax\_src\tree_util.py", line 75, in tree_unflatten
    return treedef.unflatten(leaves)
  File "F:\jax_env\lib\site-packages\distrax\_src\utils\jittable.py", line 40, in tree_unflatten
    obj = cls(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\distrax\_src\distributions\categorical.py", line 60, in __init__
    self._logits = None if logits is None else math.normalize(logits=logits)
  File "F:\jax_env\lib\site-packages\distrax\_src\utils\math.py", line 72, in normalize
    return jax.nn.log_softmax(logits, axis=-1)
  File "F:\jax_env\lib\site-packages\jax\_src\traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "F:\jax_env\lib\site-packages\jax\_src\api.py", line 618, in cache_miss
    keep_unused=keep_unused))
  File "F:\jax_env\lib\site-packages\jax\core.py", line 2031, in call_bind_with_continuation
    top_trace = find_top_trace(args)
  File "F:\jax_env\lib\site-packages\jax\core.py", line 1122, in find_top_trace
    top_tracer._assert_live()
  File "F:\jax_env\lib\site-packages\jax\interpreters\partial_eval.py", line 1486, in _assert_live
    raise core.escaped_tracer_error(self, None)
jax._src.traceback_util.UnfilteredStackTrace: jax._src.errors.UnexpectedTracerError: Encountered an unexpected tracer. A function transformed by JAX had a side effect, allowing for a reference to an intermediate value with type float32[3] wrapped in a DynamicJaxprTracer to escape the scope of the transformation.
JAX transformations require that functions explicitly return their outputs, and disallow saving intermediate values to global state.
The function being traced when the value leaked was f at <stdin>:1 traced for jit.
------------------------------
The leaked intermediate value was created on line <stdin>:2 (f).
------------------------------
When the value was created, the final 5 stack frames (most recent last) excluding JAX-internal frames were:
------------------------------
<stdin>:1 (<module>)
<stdin>:2 (f)
------------------------------

To catch the leak earlier, try setting the environment variable JAX_CHECK_TRACER_LEAKS or using the `jax.checking_leaks` context manager.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError

I thought that using dist = jax.block_until_ready(dist) inside f could fix the problem, but it doesn’t.

Asked By: Saeed Hedayatian


Answers:

This looks like the bug in distrax v0.1.2 reported in https://github.com/deepmind/distrax/issues/162. It was fixed by https://github.com/deepmind/distrax/pull/177, which is part of the distrax v0.1.3 release.

To fix the issue, you should update to distrax v0.1.3 or later.
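For context on the error itself: in the traceback above, distrax's tree_unflatten (in jittable.py) rebuilds the distribution by calling cls(*args, **kwargs), which re-runs Categorical.__init__ and its log_softmax normalization while JAX is unflattening the jitted output. That call appears to receive a tracer that is no longer live, which is what raises the UnexpectedTracerError. The sketch below is NOT distrax's code or its actual fix; it uses a hypothetical SimpleCategorical class to illustrate, in the spirit of the JAX pytree documentation, one way to make your own class safe to return from jit: array data goes in the pytree children, and tree_unflatten rebuilds the object from the leaves JAX passes back without re-running __init__.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class SimpleCategorical:
    """Hypothetical minimal categorical distribution (not distrax's class)."""

    def __init__(self, logits):
        # Normalization runs only when user code constructs the object.
        self.logits = jax.nn.log_softmax(logits, axis=-1)

    def probs(self):
        return jnp.exp(self.logits)

    def tree_flatten(self):
        # The array field is the only child; there is no static aux data.
        return (self.logits,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Bypass __init__: no JAX computation is re-run here, and the object
        # stores exactly the leaves JAX hands back (concrete arrays after jit).
        obj = object.__new__(cls)
        (obj.logits,) = children
        return obj


@jax.jit
def f(x):
    return SimpleCategorical(logits=jnp.sin(x))


dist = f(jnp.array([1.0, 2.0, 3.0]))
print(dist.probs().sum())  # a value very close to 1.0
```

With distrax v0.1.3 or later, the original f should return a usable distrax.Categorical directly, so a pattern like this is only relevant for your own pytree classes.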

Answered By: jakevdp