JAX with JIT and custom differentiation

Question:

I am working with JAX through NumPyro. Specifically, I want to use a B-spline function (e.g., as implemented in scipy.interpolate.BSpline) to transform different points into a spline, where the input depends on some of the parameters in the model. Thus, I need to be able to differentiate the B-spline in JAX (only with respect to the input argument, and not the knots or the integer order, of course!).

I can easily use jax.custom_vjp, but not when JIT is used, as it is in NumPyro. I looked at the following:

  1. https://github.com/google/jax/issues/1142
  2. https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html

and it seems like the best hope is to use a callback. However, I cannot entirely figure out how that would work.
At https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-jax-function-on-another-device-with-reverse-mode-autodiff-support

the TensorFlow example with reverse-mode autodiff does not seem to use JIT.

The example

Here is Python code that works without JIT (see the b_spline_basis() function):

from scipy.interpolate import BSpline
import numpy as np
from numpy import typing as npt
from functools import partial
import jax

# Type alias for NumPy double-precision arrays, used in the annotations below.
doubleArray = npt.NDArray[np.double]

# see
#   https://stackoverflow.com/q/74699053/5861244
#   https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray:  # type: ignore[no-any-unimported]
    out = np.zeros((deriv_basis.shape[0], deriv_basis.shape[1] - 1))

    for col_index in range(out.shape[1] - 1):
        scale = spline.t[col_index + spline.k + 1] - spline.t[col_index + 1]
        if scale != 0:
            out[:, col_index] = -deriv_basis[:, col_index + 1] / scale

    for col_index in range(1, out.shape[1]):
        scale = spline.t[col_index + spline.k] - spline.t[col_index]
        if scale != 0:
            out[:, col_index] += deriv_basis[:, col_index] / scale

    return float(spline.k) * out


def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray:  # type: ignore[no-any-unimported]
    if deriv == 0:
        return spline.design_matrix(x=x, t=spline.t, k=spline.k).todense()
    elif spline.k <= 0:
        return np.zeros((x.shape[0], spline.t.shape[0] - spline.k - 1))

    return _b_spline_deriv_inner(
        spline=spline,
        deriv_basis=_b_spline_eval(
            BSpline(t=spline.t, k=spline.k - 1, c=np.zeros(spline.c.shape[0] + 1)), x=x, deriv=deriv - 1
        ),
    )


@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    """Evaluate the ``deriv``-th derivative of the B-spline basis at ``x``, without its first column.

    The evaluation runs directly in NumPy/SciPy, not through a callback. As the question
    describes, this version therefore works with ``jax.grad`` (via the custom VJP) but not
    under ``jax.jit``. ``knots``, ``order`` and ``deriv`` are non-differentiable.
    """
    # Coefficient values are never read; BSpline just needs at least len(knots) - order - 1 of them.
    n_coef = order + knots.shape[0] - 1
    spline = BSpline(t=knots, k=order, c=np.zeros(n_coef))
    full_basis = _b_spline_eval(spline=spline, x=x, deriv=deriv)
    return full_basis[:, 1:]


def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    """Forward pass of the custom VJP for ``b_spline_basis``.

    Returns ``(value, residual)``. ``value`` is the basis derivative of order ``deriv`` (the
    primal output). ``residual`` is the derivative of order ``deriv + 1``, i.e. d(value)/dx
    row by row, which the backward pass consumes. Both drop the first basis column.
    """
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    value, slope = (_b_spline_eval(spline=spline, x=x, deriv=d)[:, 1:] for d in (deriv, deriv + 1))
    return value, slope


def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    """Backward pass: map the output cotangent ``grad`` to the cotangent of ``x``.

    Row ``i`` of the basis depends only on ``x[i]``. The cotangent of ``x[i]`` is therefore the
    inner product of row ``i`` of ``partials`` (the residual from the forward pass) with row
    ``i`` of ``grad``. ``knots``, ``order`` and ``deriv`` are not differentiated and are unused.
    """
    weighted = partials * grad
    return (jax.numpy.sum(weighted, axis=1),)


# Attach the forward/backward rules that make b_spline_basis differentiable in x.
b_spline_basis.defvjp(fwd=b_spline_basis_fwd, bwd=b_spline_basis_bwd)

if __name__ == "__main__":
    # Self-test: a cubic (order 3) basis on a clamped knot vector, checked at three points
    # against hard-coded reference values for the basis and its first three derivatives.

    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        # Collapse the basis to a scalar with fixed weights 1..p so that jax.grad applies.
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            weighted = jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=deriv, x=x), weights)
            return jax.numpy.sum(weighted)  # type: ignore[no-any-return]

        expected_value = np.sum(np.dot(basis, weights))
        expected_grad = np.dot(partials, weights)
        assert np.allclose(test_func(x), expected_value)
        assert np.allclose(jax.grad(test_func)(x), expected_grad)

    def reference(values: list[float]) -> doubleArray:
        # Values come in groups of 3 (one group per basis function, one entry per point);
        # transposing gives rows = points and columns = basis functions.
        return np.transpose(np.array(values).reshape(-1, 3))

    deriv0 = reference(
        [
            0.684, 0.166666666666667, 0.00133333333333333,
            0.096, 0.444444444444444, 0.0355555555555555,
            0.004, 0.351851851851852, 0.312148148148148,
            0, 0.037037037037037, 0.650962962962963,
        ]
    )
    deriv1 = reference(
        [
            2.52, -1, -0.04,
            1.68, -0.666666666666667, -0.666666666666667,
            0.12, 1.22222222222222, -2.29777777777778,
            0, 0.444444444444444, 3.00444444444444,
        ]
    )
    deriv2 = reference(
        [
            -69.6, 4, 0.8,
            9.6, -5.33333333333333, 5.33333333333333,
            2.4, -2.22222222222222, -15.3777777777778,
            0, 3.55555555555556, 9.24444444444445,
        ]
    )
    deriv3 = reference(
        [
            504, -8, -8,
            -144, 26.6666666666667, 26.6666666666667,
            24, -32.8888888888889, -32.8888888888889,
            0, 14.2222222222222, 14.2222222222222,
        ]
    )

    # Each case: (values at derivative order `deriv`, values at order `deriv + 1`).
    for deriv, (basis, partials) in enumerate([(deriv0, deriv1), (deriv1, deriv2), (deriv2, deriv3)]):
        test_jax(basis, partials, deriv=deriv)

Answers:

The best way to accomplish this is probably using a combination of custom_jvp and jax.pure_callback.

Unfortunately, pure_callback is relatively new and does not have great documentation yet, but you can find examples of its use in the JAX user forums (for example here).

Copied here for posterity is an example of computing the sine and cosine via NumPy callbacks in JIT-compatible code, with custom JVP rules for autodiff.

import jax
import numpy as np
# Enable 64-bit types: the callbacks below declare float64 results via jax.ShapeDtypeStruct.
jax.config.update('jax_enable_x64', True)

@jax.custom_jvp
def np_sin(x):
  """Sine of x, evaluated by np.sin on the host through jax.pure_callback (usable under jit)."""
  out_spec = jax.ShapeDtypeStruct(np.shape(x), np.float64)  # same shape as x, float64 result
  return jax.pure_callback(np.sin, out_spec, x)

@np_sin.defjvp
def _np_sin_jvp(primals, tangents):
  # d/dx sin(x) = cos(x), so the output tangent is cos(x) * dx.
  (x,) = primals
  (dx,) = tangents
  primal_out = np_sin(x)
  tangent_out = np_cos(x) * dx
  return primal_out, tangent_out

@jax.custom_jvp
def np_cos(x):
  """Cosine of x, evaluated by np.cos on the host through jax.pure_callback (usable under jit)."""
  out_spec = jax.ShapeDtypeStruct(np.shape(x), np.float64)  # same shape as x, float64 result
  return jax.pure_callback(np.cos, out_spec, x)

@np_cos.defjvp
def _np_cos_jvp(primals, tangents):
  # d/dx cos(x) = -sin(x), so the output tangent is -sin(x) * dx.
  (x,) = primals
  (dx,) = tangents
  primal_out = np_cos(x)
  tangent_out = -np_sin(x) * dx
  return primal_out, tangent_out


print(np_sin(1.0))  # expected: 0.8414709848078965
print(np_cos(1.0))  # expected: 0.5403023058681398
print(jax.jit(jax.grad(np_sin))(1.0))  # expected: 0.5403023058681398, i.e. cos(1.0)

Note that since pure_callback operates by sending data back to the host, it will generally have a lot of overhead on accelerators like GPU and TPU, although in a single-CPU setting this kind of approach can perform well.

Answered By: jakevdp

This is a follow-up to the answer provided by jakevdp (the accepted and correct answer). In the concrete example, the b_spline_basis() function and the _fwd and _bwd functions can be changed to:

@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    """Evaluate the ``deriv``-th derivative of the B-spline basis at ``x``, without its first column.

    The NumPy/SciPy evaluation runs on the host through ``jax.pure_callback``, so this function
    can be traced by ``jax.jit``. The declared result dtype is float64, which relies on
    ``jax_enable_x64`` being enabled. ``knots``, ``order`` and ``deriv`` are non-differentiable.
    """
    # len(knots) - order - 1 basis functions, minus the dropped first column.
    n_columns = knots.shape[0] - order - 2
    result_spec = jax.ShapeDtypeStruct((x.shape[0], n_columns), np.float64)

    def _host_eval(points):
        spline = BSpline(t=knots, k=order, c=np.zeros((order + knots.shape[0] - 1)))
        return _b_spline_eval(spline=spline, x=points, deriv=deriv)[:, 1:]

    return jax.pure_callback(_host_eval, result_spec, x)


def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    """Forward pass of the custom VJP for ``b_spline_basis`` (jit-compatible).

    Returns ``(value, residual)``. ``value`` is the primal output of ``b_spline_basis``.
    ``residual`` is the basis derivative of order ``deriv + 1`` (d(value)/dx row by row),
    computed on the host through ``jax.pure_callback`` and consumed by the backward pass.
    """
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    partials_spec = jax.ShapeDtypeStruct((x.shape[0], knots.shape[0] - order - 2), np.float64)

    def _host_partials(x):
        return _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)[:, 1:]

    primal_out = b_spline_basis(knots=knots, order=order, deriv=deriv, x=x)
    residual = jax.pure_callback(_host_partials, partials_spec, x=x)
    return primal_out, residual


def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    """Backward pass: map the output cotangent ``grad`` to the cotangent of ``x``.

    Row ``i`` of the basis depends only on ``x[i]``. The cotangent of ``x[i]`` is therefore the
    inner product of row ``i`` of ``partials`` (the residual from the forward pass) with row
    ``i`` of ``grad``. ``knots``, ``order`` and ``deriv`` are not differentiated and are unused.
    """
    weighted = partials * grad
    return (jax.numpy.sum(weighted, axis=1),)

then the test passes even after

# Original checks, evaluated eagerly (no jax.jit):
assert np.allclose(test_func(x), np.sum(np.dot(basis, weights)))
assert np.allclose(jax.grad(test_func)(x), np.dot(partials, weights))

was changed to

# Same checks, with test_func and its gradient compiled by jax.jit:
assert np.allclose(jax.jit(test_func)(x), np.sum(np.dot(basis, weights)))
assert np.allclose(jax.jit(jax.grad(test_func))(x), np.dot(partials, weights))

Here is the complete code for the record:

from scipy.interpolate import BSpline
import numpy as np
from numpy import typing as npt
from functools import partial
import jax

# The pure_callback result specs below declare np.float64 outputs; JAX disables 64-bit
# types by default, so enable them before any tracing happens.
jax.config.update("jax_enable_x64", True)


# Type alias for NumPy double-precision arrays, used in the annotations below.
doubleArray = npt.NDArray[np.double]

# see
#   https://stackoverflow.com/q/74699053/5861244
#   https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray:  # type: ignore[no-any-unimported]
    out = np.zeros((deriv_basis.shape[0], deriv_basis.shape[1] - 1))

    for col_index in range(out.shape[1] - 1):
        scale = spline.t[col_index + spline.k + 1] - spline.t[col_index + 1]
        if scale != 0:
            out[:, col_index] = -deriv_basis[:, col_index + 1] / scale

    for col_index in range(1, out.shape[1]):
        scale = spline.t[col_index + spline.k] - spline.t[col_index]
        if scale != 0:
            out[:, col_index] += deriv_basis[:, col_index] / scale

    return float(spline.k) * out


def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray:  # type: ignore[no-any-unimported]
    if deriv == 0:
        return spline.design_matrix(x=x, t=spline.t, k=spline.k).todense()
    elif spline.k <= 0:
        return np.zeros((x.shape[0], spline.t.shape[0] - spline.k - 1))

    return _b_spline_deriv_inner(
        spline=spline,
        deriv_basis=_b_spline_eval(
            BSpline(t=spline.t, k=spline.k - 1, c=np.zeros(spline.c.shape[0] + 1)), x=x, deriv=deriv - 1
        ),
    )


@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    """Evaluate the ``deriv``-th derivative of the B-spline basis at ``x``, without its first column.

    The NumPy/SciPy evaluation runs on the host through ``jax.pure_callback``, so this function
    can be traced by ``jax.jit``. The declared result dtype is float64, which relies on
    ``jax_enable_x64`` being enabled. ``knots``, ``order`` and ``deriv`` are non-differentiable.
    """
    # len(knots) - order - 1 basis functions, minus the dropped first column.
    n_columns = knots.shape[0] - order - 2
    result_spec = jax.ShapeDtypeStruct((x.shape[0], n_columns), np.float64)

    def _host_eval(points):
        spline = BSpline(t=knots, k=order, c=np.zeros((order + knots.shape[0] - 1)))
        return _b_spline_eval(spline=spline, x=points, deriv=deriv)[:, 1:]

    return jax.pure_callback(_host_eval, result_spec, x)


def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    """Forward pass of the custom VJP for ``b_spline_basis`` (jit-compatible).

    Returns ``(value, residual)``. ``value`` is the primal output of ``b_spline_basis``.
    ``residual`` is the basis derivative of order ``deriv + 1`` (d(value)/dx row by row),
    computed on the host through ``jax.pure_callback`` and consumed by the backward pass.
    """
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    partials_spec = jax.ShapeDtypeStruct((x.shape[0], knots.shape[0] - order - 2), np.float64)

    def _host_partials(x):
        return _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)[:, 1:]

    primal_out = b_spline_basis(knots=knots, order=order, deriv=deriv, x=x)
    residual = jax.pure_callback(_host_partials, partials_spec, x=x)
    return primal_out, residual


def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    """Backward pass: map the output cotangent ``grad`` to the cotangent of ``x``.

    Row ``i`` of the basis depends only on ``x[i]``. The cotangent of ``x[i]`` is therefore the
    inner product of row ``i`` of ``partials`` (the residual from the forward pass) with row
    ``i`` of ``grad``. ``knots``, ``order`` and ``deriv`` are not differentiated and are unused.
    """
    weighted = partials * grad
    return (jax.numpy.sum(weighted, axis=1),)


# Attach the forward/backward rules that make b_spline_basis differentiable in x.
b_spline_basis.defvjp(fwd=b_spline_basis_fwd, bwd=b_spline_basis_bwd)


if __name__ == "__main__":
    # Self-test under jax.jit: a cubic (order 3) basis on a clamped knot vector, checked at
    # three points against hard-coded reference values for the basis and its first three
    # derivatives.

    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        # Collapse the basis to a scalar with fixed weights 1..p so that jax.grad applies.
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            weighted = jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=deriv, x=x), weights)
            return jax.numpy.sum(weighted)  # type: ignore[no-any-return]

        expected_value = np.sum(np.dot(basis, weights))
        expected_grad = np.dot(partials, weights)
        compiled_value = jax.jit(test_func)
        compiled_grad = jax.jit(jax.grad(test_func))
        assert np.allclose(compiled_value(x), expected_value)
        assert np.allclose(compiled_grad(x), expected_grad)

    def reference(values: list[float]) -> doubleArray:
        # Values come in groups of 3 (one group per basis function, one entry per point);
        # transposing gives rows = points and columns = basis functions.
        return np.transpose(np.array(values).reshape(-1, 3))

    deriv0 = reference(
        [
            0.684, 0.166666666666667, 0.00133333333333333,
            0.096, 0.444444444444444, 0.0355555555555555,
            0.004, 0.351851851851852, 0.312148148148148,
            0, 0.037037037037037, 0.650962962962963,
        ]
    )
    deriv1 = reference(
        [
            2.52, -1, -0.04,
            1.68, -0.666666666666667, -0.666666666666667,
            0.12, 1.22222222222222, -2.29777777777778,
            0, 0.444444444444444, 3.00444444444444,
        ]
    )
    deriv2 = reference(
        [
            -69.6, 4, 0.8,
            9.6, -5.33333333333333, 5.33333333333333,
            2.4, -2.22222222222222, -15.3777777777778,
            0, 3.55555555555556, 9.24444444444445,
        ]
    )
    deriv3 = reference(
        [
            504, -8, -8,
            -144, 26.6666666666667, 26.6666666666667,
            24, -32.8888888888889, -32.8888888888889,
            0, 14.2222222222222, 14.2222222222222,
        ]
    )

    # Each case: (values at derivative order `deriv`, values at order `deriv + 1`).
    for deriv, (basis, partials) in enumerate([(deriv0, deriv1), (deriv1, deriv2), (deriv2, deriv3)]):
        test_jax(basis, partials, deriv=deriv)
Categories: questions Tags: , ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
a check mark at the top-right corner.