JAX performance problems

Question:

I am obviously not following best practices, but maybe that’s because I don’t know what they are. Anyway, my goal is to generate a tubular neighborhood about a curve in three dimensions. A curve is given by a function returning an array of length three, f(t) = jnp.array([x(t), y(t), z(t)]).

Now, first we compute the unit tangent:

import numpy as np
import jax.numpy as jnp
from jax import jacfwd

def get_uvec2(f):
  tanvec = jacfwd(f)
  return lambda x: tanvec(x)/jnp.linalg.norm(tanvec(x))
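As a quick sanity check, the unit tangent of the unit circle (the f1 used at the end of the question) at t = 0 should point along +y and have length 1. The sketch below uses a hypothetical variant, unit_tangent, that evaluates the Jacobian once per call instead of twice; this is a small saving when running eagerly, as the question's loop does, not the main bottleneck.

```python
import jax.numpy as jnp
from jax import jacfwd

def unit_tangent(f):
    """Hypothetical variant of get_uvec2: returns t -> f'(t) / |f'(t)|."""
    df = jacfwd(f)  # scalar t and 3-vector f(t), so df(t) has shape (3,)
    def tangent(t):
        v = df(t)  # evaluate the derivative once and reuse it
        return v / jnp.linalg.norm(v)
    return tangent

# The unit circle from the question, traversed once as t goes from 0 to 1.
f1 = lambda x: jnp.array([jnp.cos(2 * jnp.pi * x), jnp.sin(2 * jnp.pi * x), 0.0])

T = unit_tangent(f1)
print(T(0.0))                    # unit vector along +y (counterclockwise motion at (1, 0, 0))
print(jnp.linalg.norm(T(0.3)))   # 1.0 up to rounding
```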

Next, we compute the unit normal, i.e. the normalized derivative of the unit tangent:

def get_cvec(f):
  return get_uvec2(get_uvec2(f))
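For the circle f1 from the end of the question, the normal produced by get_cvec should point from the curve toward the center, and it should be orthogonal to the unit tangent. A small self-contained check, restating get_uvec2 so it runs on its own (t = 0.37 is an arbitrary sample point):

```python
import jax.numpy as jnp
from jax import jacfwd

def get_uvec2(f):
    tanvec = jacfwd(f)
    return lambda x: tanvec(x) / jnp.linalg.norm(tanvec(x))

def get_cvec(f):
    # Normalized derivative of the unit tangent: the principal normal N(t).
    return get_uvec2(get_uvec2(f))

f1 = lambda x: jnp.array([jnp.cos(2 * jnp.pi * x), jnp.sin(2 * jnp.pi * x), 0.0])

T = get_uvec2(f1)
N = get_cvec(f1)
t = 0.37
print(N(0.0))                # points from (1, 0, 0) toward the circle's center
print(jnp.dot(T(t), N(t)))   # close to 0: N is orthogonal to T
```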

Third, we compute the orthonormal (Frenet) frame of tangent, normal, and binormal at a point:

def get_frame(f):
  tt = get_uvec2(f)
  tt2 = get_cvec(f)
  def first2(t):
    x = tt(t)
    y = tt2(t)
    tt3 = (jnp.cross(x, y))
    return jnp.array([x, y, tt3])
  return first2
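One property worth checking is that the rows of the frame (T, N, B) are orthonormal, so F @ F.T is the identity. The sketch below restates the question's helpers compactly. Note that the principal normal, and hence this frame, is undefined wherever the curvature vanishes (for example on a straight segment), because the normalization then divides by zero.

```python
import jax.numpy as jnp
from jax import jacfwd

def get_uvec2(f):
    tanvec = jacfwd(f)
    return lambda x: tanvec(x) / jnp.linalg.norm(tanvec(x))

def get_frame(f):
    T = get_uvec2(f)              # unit tangent
    N = get_uvec2(get_uvec2(f))   # principal normal
    def frame(t):
        x, y = T(t), N(t)
        return jnp.array([x, y, jnp.cross(x, y)])  # rows: T, N, B
    return frame

f1 = lambda x: jnp.array([jnp.cos(2 * jnp.pi * x), jnp.sin(2 * jnp.pi * x), 0.0])
F = get_frame(f1)(0.2)

# Rows are orthonormal, so F @ F.T is (numerically) the 3x3 identity.
print(jnp.round(F @ F.T, 5))
# For a circle in the xy-plane the binormal is the constant +z direction.
print(F[2])
```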

which we use to generate a point on the unit circle in the normal plane at a given point:

def get_point(frame, s):
  v1 = frame[1, :]
  v2 = frame[2, :]
  return jnp.cos(s) * v1 + jnp.sin(s) * v2
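Because v1 and v2 are the normal and binormal, get_point(frame, s) sweeps out a unit circle in the plane orthogonal to the tangent as s runs over [0, 2π]. A sketch with a hypothetical hand-built frame (T = e_x, N = e_y, B = e_z), using jax.vmap to evaluate several values of s at once:

```python
import jax
import jax.numpy as jnp

def get_point(frame, s):
    # frame rows are (T, N, B); the result lies in the plane spanned by N and B.
    v1 = frame[1, :]
    v2 = frame[2, :]
    return jnp.cos(s) * v1 + jnp.sin(s) * v2

# Hypothetical example frame: T = e_x, N = e_y, B = e_z.
frame = jnp.eye(3)
s = jnp.linspace(0.0, 2 * jnp.pi, 8)

# Map over s only; the same frame is reused for every s.
pts = jax.vmap(get_point, in_axes=(None, 0))(frame, s)  # shape (8, 3)
print(pts[:, 0])                     # all zeros: orthogonal to T
print(jnp.linalg.norm(pts, axis=1))  # all ones: a unit circle
```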

And now we generate the point on the tubular neighborhood corresponding to a pair of parameters:

def get_grid(f, eps):
  ffunc = get_frame(f)
  def grid(t, s):
    base = f(t)
    frame = ffunc(t)
    return base + eps * get_point(frame, s)
  return grid
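For the unit circle with eps = 0.1, grid points can be checked by hand: at t = 0 the base point is (1, 0, 0), the normal is (-1, 0, 0) and the binormal is (0, 0, 1), so s = 0 should move the point toward the center and s = π/2 should lift it out of the xy-plane. A self-contained sketch, with the question's helpers restated compactly:

```python
import jax.numpy as jnp
from jax import jacfwd

def get_uvec2(f):
    tanvec = jacfwd(f)
    return lambda x: tanvec(x) / jnp.linalg.norm(tanvec(x))

def get_frame(f):
    tt, tt2 = get_uvec2(f), get_uvec2(get_uvec2(f))
    def first2(t):
        x, y = tt(t), tt2(t)
        return jnp.array([x, y, jnp.cross(x, y)])
    return first2

def get_point(frame, s):
    return jnp.cos(s) * frame[1, :] + jnp.sin(s) * frame[2, :]

def get_grid(f, eps):
    ffunc = get_frame(f)
    def grid(t, s):
        # Base point on the curve, offset by eps within the normal plane.
        return f(t) + eps * get_point(ffunc(t), s)
    return grid

f1 = lambda x: jnp.array([jnp.cos(2 * jnp.pi * x), jnp.sin(2 * jnp.pi * x), 0.0])
g = get_grid(f1, 0.1)
print(g(0.0, 0.0))         # close to (0.9, 0, 0): pushed toward the center
print(g(0.0, jnp.pi / 2))  # close to (1, 0, 0.1): pushed along the binormal
```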

Then we put it all together:

def get_reg_grid(f, num1, num2, eps):
  plist = []
  tarray = jnp.linspace(start = 0.0, stop = 1.0, num = num1)
  sarray = jnp.linspace(start = 0.0, stop = 2 * jnp.pi, num = num2)
  g = get_grid(f, eps)
  for t in tarray:
    for s in sarray:
      plist.append(g(t, s))
  return jnp.vstack(plist)

Finally, we use it to compute the tubular neighborhood around the unit circle in the xy-plane:

f1 = lambda x: jnp.array([jnp.cos(2 * jnp.pi * x), jnp.sin(2 * jnp.pi * x), 0.0])

fff = np.array(get_reg_grid(f1, 200, 200, 0.1))

The good news is that it all works. The bad news is that this computation takes well over an hour. Where did I go wrong?

Asked By: Igor Rivin


Answers:

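JAX and NumPy share one key rule of thumb for getting good performance: if you are writing for loops over array values, your code will probably be slow, because each iteration pays Python and dispatch overhead for what is usually a tiny amount of work. In your case, each of the 40,000 calls to g(t, s) evaluates several nested jacfwd derivatives eagerly, one operation at a time.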
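As a toy illustration (not from the original answer), here is the same elementwise computation written as a Python loop and as a single vectorized expression. Both produce the same array, but the loop issues thousands of small operations one at a time:

```python
import jax.numpy as jnp

xs = jnp.linspace(0.0, 1.0, 1000)

# Loop: each iteration indexes xs, calls sin, and multiplies, all dispatched separately.
looped = jnp.stack([2.0 * jnp.sin(x) for x in xs])

# Vectorized: the same arithmetic expressed as whole-array operations.
vectorized = 2.0 * jnp.sin(xs)
```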

To make your code more performant, you should replace your loops with vectorized operations. One nice feature of JAX is jax.vmap, a vectorizing transform which makes this relatively easy. You can also use jax.jit to JIT-compile your function and get even faster execution.
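Nesting two vmap calls with complementary in_axes evaluates a function on every (t, s) pair, much like an outer product. A sketch with a hypothetical two-parameter function h:

```python
import jax
import jax.numpy as jnp

def h(t, s):
    # Hypothetical function of two scalar parameters, returning a 3-vector.
    return jnp.array([t, s, t * s])

t = jnp.linspace(0.0, 1.0, 4)
s = jnp.linspace(0.0, 2 * jnp.pi, 5)

# Inner vmap maps over s with t held fixed; outer vmap maps over t with s held fixed.
hs = jax.vmap(h, in_axes=(None, 0))
hts = jax.jit(jax.vmap(hs, in_axes=(0, None)))

out = hts(t, s)
print(out.shape)  # (4, 5, 3): out[i, j] == h(t[i], s[j])
```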

Here’s a modified version of your get_reg_grid function that returns the same result with much faster execution:

import jax
from functools import partial

@partial(jax.jit, static_argnames=['f', 'num1', 'num2'])
def get_reg_grid(f, num1, num2, eps):
  tarray = jnp.linspace(start = 0.0, stop = 1.0, num = num1)
  sarray = jnp.linspace(start = 0.0, stop = 2 * jnp.pi, num = num2)
  g = get_grid(f, eps)
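  g = jax.vmap(g, in_axes=(None, 0))  # map over s, holding t fixed
  g = jax.vmap(g, in_axes=(0, None))  # map over t, holding s fixed
  # g(tarray, sarray) has shape (num1, num2, 3); vstack flattens it to (num1 * num2, 3)
  return jnp.vstack(g(tarray, sarray))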
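To check that the vectorized version matches the original loop, one can compare both on a small grid. In this sketch the question's helpers are condensed into a single get_grid, get_reg_grid_loop is a hypothetical name for the question's loop version, and jnp.vstack is replaced by an explicit reshape(-1, 3), which should give the same (num1 * num2, 3) layout with t varying slowest.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import jacfwd

def get_uvec2(f):
    tanvec = jacfwd(f)
    return lambda x: tanvec(x) / jnp.linalg.norm(tanvec(x))

def get_grid(f, eps):
    # Condensed restatement of the question's get_frame, get_point and get_grid.
    T, N = get_uvec2(f), get_uvec2(get_uvec2(f))
    def grid(t, s):
        x, y = T(t), N(t)
        b = jnp.cross(x, y)
        return f(t) + eps * (jnp.cos(s) * y + jnp.sin(s) * b)
    return grid

def get_reg_grid_loop(f, num1, num2, eps):
    # The question's double loop (slow; only used here on a tiny grid).
    g = get_grid(f, eps)
    ts = jnp.linspace(0.0, 1.0, num1)
    ss = jnp.linspace(0.0, 2 * jnp.pi, num2)
    return jnp.vstack([g(t, s) for t in ts for s in ss])

@partial(jax.jit, static_argnames=['f', 'num1', 'num2'])
def get_reg_grid(f, num1, num2, eps):
    ts = jnp.linspace(0.0, 1.0, num1)
    ss = jnp.linspace(0.0, 2 * jnp.pi, num2)
    g = jax.vmap(jax.vmap(get_grid(f, eps), in_axes=(None, 0)), in_axes=(0, None))
    # (num1, num2, 3) -> (num1 * num2, 3), matching the loop's t-outer, s-inner order.
    return g(ts, ss).reshape(-1, 3)

f1 = lambda x: jnp.array([jnp.cos(2 * jnp.pi * x), jnp.sin(2 * jnp.pi * x), 0.0])
fast = get_reg_grid(f1, 6, 7, 0.1)
slow = get_reg_grid_loop(f1, 6, 7, 0.1)
print(fast.shape)                           # (42, 3)
print(jnp.allclose(fast, slow, atol=1e-5))
```

For this particular curve every point should also lie at distance eps from the core circle, which is an easy extra check on the whole pipeline.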

With this approach, once the function has been compiled, your code executes in about 300 microseconds:

%timeit get_reg_grid(f1, 200, 200, 0.1).block_until_ready()
# 296 µs ± 157 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
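Two details matter when reproducing this kind of timing. JAX dispatches work asynchronously, which is why the measurement calls block_until_ready(); and the first call to a jitted function includes tracing and compilation, which is repeated whenever a static argument such as num1 or num2 changes. A rough sketch using a hypothetical stand-in function, ring (no specific timings are implied):

```python
import time
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=['n'])
def ring(n, eps):
    # Hypothetical stand-in for get_reg_grid: n points on a circle of radius eps.
    s = jnp.linspace(0.0, 2 * jnp.pi, n)
    return eps * jnp.stack([jnp.cos(s), jnp.sin(s)], axis=1)

def timed(fn, *args):
    start = time.perf_counter()
    out = fn(*args).block_until_ready()  # wait for the asynchronous computation to finish
    return out, time.perf_counter() - start

_, first = timed(ring, 200, 0.1)   # includes tracing and XLA compilation
_, second = timed(ring, 200, 0.1)  # reuses the cached executable
_, new_n = timed(ring, 300, 0.1)   # new static value of n: compiles again
print(f"first: {first:.4f}s  second: {second:.6f}s  new n: {new_n:.4f}s")
```

Changing eps does not trigger recompilation, since it is traced rather than static; passing a new lambda object as f in get_reg_grid does, because static arguments are compared by hash and equality.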
Answered By: jakevdp