JAX – jitting functions: parameters vs "global" variables

Question:

I have the following doubt about JAX. I’ll use an example from the official optax docs to illustrate it:

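import jax
import optax

# `loss`, `TRAINING_DATA`, `LABELS` and `initial_params` are defined earlier in the optax docs example.
def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params: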
  opt_state = optimizer.init(params)

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    if i % 100 == 0:
      print(f'step {i}, loss: {loss_value}')

  return params

# Finally, we can fit our parametrized function using the Adam optimizer
# provided by optax.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(initial_params, optimizer)

In this example, the function step uses the variable optimizer even though it is not passed as a function argument (the function is jitted, and optax.GradientTransformation is not a supported argument type). However, the same function receives its other inputs as parameters (i.e., params, opt_state, batch, labels). I understand that JAX functions need to be pure in order to be jitted, but what about read-only input variables? Is there any difference between passing a variable through the function arguments and accessing it directly because it is in the scope of step? What if the variable is not constant but is modified between separate step calls? Is it treated like a static argument when accessed directly? Or is it simply jitted away, so that later modifications are ignored?

To be more specific, let’s look at the following example:

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    updates = jax.tree_util.tree_map(lambda u: u * extra_learning_rate, updates)  # scale every leaf of the updates pytree
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels)
    extra_learning_rate = 0.01 # does this affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels)

  return params

vs

def fit(params: optax.Params, optimizer: optax.GradientTransformation) -> optax.Params:
  opt_state = optimizer.init(params)
  extra_learning_rate = 0.1

  @jax.jit
  def step(params, opt_state, batch, labels, extra_lr):
    loss_value, grads = jax.value_and_grad(loss)(params, batch, labels)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    updates = jax.tree_util.tree_map(lambda u: u * extra_lr, updates)  # scale every leaf of the updates pytree
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for i, (batch, labels) in enumerate(zip(TRAINING_DATA, LABELS)):
    extra_learning_rate = 0.1
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)
    extra_learning_rate = 0.01 # does this now affect the next `step` call?
    params, opt_state, loss_value = step(params, opt_state, batch, labels, extra_learning_rate)

  return params

From my limited experiments, the two versions behave differently: in the global case the second step call doesn’t use the new learning rate, and no ‘re-jitting’ happens. However, I’d like to know whether there are any standard practices or rules I need to be aware of. I’m writing a library where performance is fundamental, and I don’t want to miss out on jit optimizations because I’m doing things wrong.

Asked By: Liuka

||

Answers:

During JIT tracing, JAX treats global values as implicit arguments to the function being traced. You can see this reflected in the jaxpr representing the function.

Here are two simple functions that return equivalent results, one with implicit arguments and one with explicit:

import jax
import jax.numpy as jnp

def f_explicit(a, b):
  return a + b

def f_implicit(b):
  return a_global + b

a_global = jnp.arange(5.0)
b = jnp.ones(5)

print(jax.make_jaxpr(f_explicit)(a_global, b))
# { lambda ; a:f32[5] b:f32[5]. let c:f32[5] = add a b in (c,) }

print(jax.make_jaxpr(f_implicit)(b))
# { lambda a:f32[5]; b:f32[5]. let c:f32[5] = add a b in (c,) }

Notice the only difference in the two jaxprs is that in f_implicit, the a variable comes before the semicolon: this is the way that jaxpr representations indicate the argument is passed via closure rather than via an explicit argument. But the computation generated by these two functions will be identical.
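
As a quick sanity check of the claim above, here is a minimal sketch that jit-compiles both variants and compares their outputs. It reuses the names from the snippet above; out_explicit and out_implicit are names introduced here only for illustration.

```python
import jax
import jax.numpy as jnp

a_global = jnp.arange(5.0)
b = jnp.ones(5)

def f_explicit(a, b):
    # `a` arrives as an ordinary traced argument
    return a + b

def f_implicit(b):
    # `a_global` is picked up from the enclosing (module) scope
    return a_global + b

out_explicit = jax.jit(f_explicit)(a_global, b)
out_implicit = jax.jit(f_implicit)(b)

# Both variants compute the same element-wise sum.
print(jnp.array_equal(out_explicit, out_implicit))  # prints True
```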

That said, one difference to be aware of is that when a value passed by closure is a hashable constant (for example, a Python scalar), it will be treated as static within the traced function (similar to when explicit arguments are marked static via static_argnums or static_argnames within jax.jit):

a_global = 1.0
print(jax.make_jaxpr(f_implicit)(b))
# { lambda ; a:f32[5]. let b:f32[5] = add 1.0 a in (b,) }

Notice in the jaxpr representation the constant value is inserted directly as an argument to the add operation. The explicit way to get the same result for a JIT-compiled function would look something like this:

from functools import partial

@partial(jax.jit, static_argnames=['a'])
def f_explicit(a, b):
  return a + b
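
To connect this back to the learning-rate example in the question, the sketch below compares three ways of supplying a scalar to a jitted function: by closure, as a regular (traced) argument, and as a static argument. All names here (extra_lr, scale_closure, scale_traced, scale_static, trace_log) are hypothetical and chosen only for illustration. It relies on the fact that Python side effects inside a jitted function run only while it is being traced, so trace_log records how often each function was traced.

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_log = []  # appended to only while a function is being traced

extra_lr = 0.1  # hypothetical closed-over Python float

@jax.jit
def scale_closure(x):
    trace_log.append("closure")
    return x * extra_lr  # value is read from the enclosing scope at trace time

@jax.jit
def scale_traced(x, lr):
    trace_log.append("traced")
    return x * lr  # lr is a runtime input of the compiled function

@partial(jax.jit, static_argnames=["lr"])
def scale_static(x, lr):
    trace_log.append("static")
    return x * lr  # lr is baked in; a new value triggers a new trace

x = jnp.ones(3)

closure_first = scale_closure(x)
traced_first = scale_traced(x, extra_lr)
static_first = scale_static(x, lr=extra_lr)

extra_lr = 0.01  # rebind the variable between calls

closure_second = scale_closure(x)             # cached trace: still scales by 0.1
traced_second = scale_traced(x, extra_lr)     # cached trace: scales by 0.01
static_second = scale_static(x, lr=extra_lr)  # traced again: scales by 0.01

print(closure_second, traced_second, static_second)
print(trace_log)
```

Under these assumptions, only the closure version keeps using the old value. So if a value is meant to change between calls, passing it as an explicit argument (traced when it is numeric, static when it must be a compile-time constant) is the safer choice, while closures suit values that stay fixed for the lifetime of the jitted function, such as optimizer in the question.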

Answered By: jakevdp