Nested vmap in pmap – JAX

Question:

I can currently run simulations in parallel on one GPU using vmap. To speed things up, I want to spread the batched simulations across multiple GPUs using pmap. However, when I pmap the vmapped function, I get a tracing error.

The code I use to get a trajectory state is:

traj_state = vmap(run_trajectory, in_axes=(0, None, 0))(sim_state, timings, lambda_array)
                                                                        

where run_trajectory runs a single simulation and lambda_array holds the parameters for each simulation. I then tried to nest this inside a pmap:

pmap(vmap(run_trajectory, in_axes=(0, None, 0)), in_axes=(0, None, 0))(reshaped_sim_state, timings, reshaped_lambda_array)

In doing so I get the error:

While tracing the function run_trajectory for pmap, this concrete value was not available in Python because it depends on the value of the argument 'timings'.

I'm quite new to JAX, and although there is documentation on errors involving traced values, I'm not sure how to solve this problem.

Asked By: Anton B


Answers:

vmap and pmap treat in_axes=None differently. In vmap, an input with in_axes=None is unmapped and passed to the function as-is, so it stays static (i.e. untraced). In pmap, an input with in_axes=None is also unmapped, but it is still traced:

from jax import vmap, pmap
import jax.numpy as jnp

def f(x, condition):
  # requires untraced condition:
  return x if condition else x + 1

x = jnp.arange(4)
vmap(f, in_axes=(0, None))(x, True)
# Array([0, 1, 2, 3], dtype=int32)

pmap(f, in_axes=(0, None))(x, True)
# ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: 

To keep a value untraced under pmap, bind it before mapping by partially applying the function with functools.partial; for example:

from functools import partial

vmap(partial(f, condition=True), in_axes=0)(x)
# Array([0, 1, 2, 3], dtype=int32)

pmap(partial(f, condition=True), in_axes=0)(x)
# Array([0, 1, 2, 3], dtype=int32)
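
The snippets above map over a 4-element axis, so the pmap calls need at least 4 devices. Below is a minimal sketch of the same comparison, sized with jax.local_device_count() so it should also run on a single CPU or GPU. The last call uses pmap's static_broadcasted_argnums, an alternative not discussed in this thread: it marks an argument as static, which requires the value to be hashable (a bool or a tuple works, a JAX array does not).

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.errors import ConcretizationTypeError

def f(x, condition):
    # A Python `if` needs a concrete (untraced) `condition`.
    return x if condition else x + 1

n = jax.local_device_count()  # pmap's mapped axis must not exceed this
x = jnp.arange(n)

# vmap passes the in_axes=None argument through untraced, so this works.
out_vmap = jax.vmap(f, in_axes=(0, None))(x, True)

# pmap traces the in_axes=None argument anyway, so the Python `if` raises.
try:
    jax.pmap(f, in_axes=(0, None))(x, True)
    pmap_raised = False
except ConcretizationTypeError:
    pmap_raised = True

# Binding the value with partial keeps it out of pmap's traced arguments.
out_partial = jax.pmap(partial(f, condition=True))(x)

# Alternative: declare positional argument 1 static (its value must be hashable).
out_static = jax.pmap(f, static_broadcasted_argnums=(1,))(x, True)
```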

In your case, applying this solution might look like this:

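def run(sim_state, lambda_array, timings=timings):
  # timings is bound as a default argument instead of being passed
  # through vmap/pmap, so it is never traced
  return run_trajectory(sim_state, timings, lambda_array)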

vmap(run)(sim_state, lambda_array)

pmap(vmap(run))(reshaped_sim_state, reshaped_lambda_array)
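
The thread never shows how the reshaped_* arrays are built. Here is a minimal end-to-end sketch under stated assumptions: the run_trajectory body, the (t_end, dt) layout of timings, the state shape, and per_dev are all hypothetical, and the batch size is assumed to divide evenly across devices. timings is captured by a closure here, which has the same effect as the default-argument version above.

```python
import jax
import jax.numpy as jnp

def run_trajectory(sim_state, timings, lam):
    # Hypothetical stand-in for one simulation. int()/range() need concrete
    # Python numbers, so this breaks if `timings` arrives as a tracer.
    t_end, dt = timings
    for _ in range(int(t_end / dt)):
        sim_state = sim_state - dt * lam * sim_state
    return sim_state

timings = (1.0, 0.25)             # hypothetical (t_end, dt) as Python floats
n_dev = jax.local_device_count()
per_dev = 3                       # simulations per device (assumed)
batch = n_dev * per_dev           # batch must split evenly across devices

sim_state = jnp.ones((batch, 2))
lambda_array = jnp.linspace(0.0, 1.0, batch)

def run(sim_state, lam):
    # `timings` is closed over, so neither vmap nor pmap ever traces it.
    return run_trajectory(sim_state, timings, lam)

# Axis 0 -> devices (pmap), axis 1 -> simulations per device (vmap).
reshaped_sim_state = sim_state.reshape(n_dev, per_dev, 2)
reshaped_lambda_array = lambda_array.reshape(n_dev, per_dev)

out = jax.pmap(jax.vmap(run))(reshaped_sim_state, reshaped_lambda_array)
traj_state = out.reshape(batch, 2)  # merge the device axis back into the batch
```

Reshaping rather than slicing keeps each device's chunk contiguous, so traj_state comes back in the same order as lambda_array.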
Answered By: jakevdp

I can seemingly avoid this problem by binding the timing values with partial before applying vmap, that is:

run_trajectory = partial(run_trajectory, starting_time=timings)

traj_state = pmap(vmap(run_trajectory, in_axes=(0, 0)))(reshaped_sim_state, reshaped_lambda_array)
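
Whether this version works depends on where starting_time sits in run_trajectory's signature, which the thread does not show. If it is the second positional parameter, as the call run_trajectory(sim_state, timings, lambda_array) suggests, the second positional argument also lands on starting_time and Python raises a TypeError. The sketch below uses two hypothetical functions, traj_mid and traj_last, to show both cases.

```python
from functools import partial

import jax
import jax.numpy as jnp

def traj_mid(sim_state, starting_time, lam):
    # Hypothetical: starting_time is the 2nd positional parameter,
    # the same order as run_trajectory(sim_state, timings, lambda_array).
    return sim_state * lam + starting_time

def traj_last(sim_state, lam, starting_time):
    # Hypothetical: starting_time is the last parameter.
    return sim_state * lam + starting_time

# Middle parameter: the 2nd positional argument also fills starting_time,
# so Python raises "got multiple values for argument 'starting_time'".
try:
    partial(traj_mid, starting_time=1.0)(jnp.ones(2), 2.0)
    clashed = False
except TypeError:
    clashed = True

# Last parameter: the positional arguments fill sim_state and lam as intended.
n = jax.local_device_count()
s = jnp.ones((n, 3, 2))
lam = jnp.full((n, 3), 2.0)
ok = jax.pmap(jax.vmap(partial(traj_last, starting_time=1.0)))(s, lam)
```

When the bound parameter sits in the middle of the signature, a small wrapper like run in the answer above avoids the clash.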
Answered By: Anton B