Getting a type error while using fori_loop with JAX

Question:

I’m developing code using JAX, and I wanted to JIT some parts of it that had big loops. I didn’t want the loops to be unrolled, so I used fori_loop, but I’m getting an error and can’t figure out what I’m doing wrong.

The error is:

  self.arr = self.arr.reshape(new_shape+new_shape)
TypeError: 'aval_method' object is not callable

I was able to reduce the code to the following:

import jax.numpy as jnp
import jax

class UB():
    def __init__(self, arr, new_shape):

        self.arr = arr
        self.shape = new_shape
        if type(arr) is not object:
            self.arr = self.arr.reshape(new_shape+new_shape)

        
    def _tree_flatten(self):
        children = (self.arr,)  # arrays / dynamic values
        aux_data = {
            'new_shape': self.shape            
        }  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)


class UM():
    def __init__(self, arr, r=None):

        self.arr = arr
        self.r = tuple(r)
    
    def _tree_flatten(self):
        children = (self.arr,)  # arrays / dynamic values
        aux_data = {
            'r': self.r
        }  # static values
        return (children, aux_data)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, **aux_data)



for C in [UB, UM]:
    jax.tree_util.register_pytree_node(
        C,
        C._tree_flatten,
        C._tree_unflatten,
    )


def s_w(ub, ums):
    e  = jnp.identity(2)
    u = UM(e, [2])
    ums[0] = u
    return ub, ums

def s_c(t, uns):
    n = 20
    ums = []
    for un in uns:
        ums.append(UM(un, [2]))

    tub = UB(t.arr, t.r)
    
    s_loop_body = lambda i,x: s_w( ub=x[0], ums=x[1])
    
    tub, ums = jax.lax.fori_loop(0, n, s_loop_body, (tub, ums))
    # for i in range(n):
    #     tub, ums = s_loop_body(i, (tub, ums))

    return jnp.array([u.arr.flatten() for u in ums])


uns = jnp.array([jnp.array([1, 2, 3, 4]) for _ in range(6)])
t = UM(jnp.array([1, 0, 0, 1]), r=[2])
uns = s_c(t, uns)

Has anyone encountered this issue or can explain how to fix it?

Asked By: Alon Kukliansky


Answers:

The issue is discussed here: https://jax.readthedocs.io/en/latest/pytrees.html#custom-pytrees-and-initialization

Namely, JAX uses pytrees as general containers, and sometimes initializes them with abstract values or other placeholders. So you cannot assume that the arguments passed to a custom pytree's constructor will be arrays. You might account for this by doing something like the following:

class UB():
    def __init__(self, arr, new_shape):
        self.arr = arr
        self.shape = new_shape
        if isinstance(arr, jnp.ndarray):
            self.arr = self.arr.reshape(new_shape+new_shape)
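A minimal self-contained sketch of the guarded constructor, assuming a recent JAX where tracers pass `isinstance(x, jnp.ndarray)`. The aux data is stored as a tuple rather than the dict used in the question (an assumption made to keep it hashable), and `double` is a hypothetical helper. The sketch rebuilds a `UB` from an `object()` placeholder leaf via `jax.tree_util.tree_unflatten`, which skips the reshape, and then runs it under `jax.jit`, where the leaf is a tracer and the reshape still happens. By contrast, the question's check `type(arr) is not object` is true for anything except a bare `object()`, including abstract values passed in during tracing. That is most likely why `reshape` was looked up on an abstract value and raised `'aval_method' object is not callable`.

```python
import jax
import jax.numpy as jnp

class UB:
    def __init__(self, arr, new_shape):
        self.arr = arr
        self.shape = new_shape
        # Only reshape real or traced arrays; leave placeholders untouched.
        if isinstance(arr, jnp.ndarray):
            self.arr = self.arr.reshape(new_shape + new_shape)

    def _tree_flatten(self):
        # Assumption: aux data kept as a hashable tuple instead of a dict.
        return (self.arr,), (self.shape,)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

jax.tree_util.register_pytree_node(UB, UB._tree_flatten, UB._tree_unflatten)

ub = UB(jnp.array([1, 0, 0, 1]), (2,))
treedef = jax.tree_util.tree_structure(ub)

# Rebuild the same structure from a non-array placeholder leaf:
# the isinstance guard skips the reshape, so no error is raised.
placeholder = jax.tree_util.tree_unflatten(treedef, [object()])
print(type(placeholder.arr))  # <class 'object'>

# Hypothetical helper: under jit the leaf is a tracer, which passes the guard.
@jax.jit
def double(u):
    return UB(u.arr * 2, u.shape)

print(double(ub).arr.shape)  # (2, 2)
```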

When I run your code with this modification, it gets past the error you asked about. Unfortunately, it then triggers another error because the body function of the fori_loop does not have a valid signature: the arr attributes of the ums have different shapes on input and output, which fori_loop does not support (the loop carry must keep the same structure, shapes, and dtypes on every iteration).
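Below is a sketch of a loop body whose carry keeps the same shapes and dtypes. It assumes the intent of `s_w` is just to overwrite `ums[0]` with an identity matrix, which is a guess. The `UB` part of the carry is omitted for brevity, and `body` is a hypothetical name. Note that in the question, with default 32-bit settings, `jnp.identity(2)` is float32 with shape (2, 2), while each `un` is int32 with shape (4,). So both the shape and the dtype of `ums[0]` change. Here the replacement is built with the incoming leaf's dtype and reshaped to its shape, and a new list is returned instead of mutating the input.

```python
import jax
import jax.numpy as jnp

class UM:
    def __init__(self, arr, r):
        self.arr = arr
        self.r = tuple(r)

    def _tree_flatten(self):
        # Assumption: aux data kept as a hashable tuple instead of a dict.
        return (self.arr,), (self.r,)

    @classmethod
    def _tree_unflatten(cls, aux_data, children):
        return cls(*children, *aux_data)

jax.tree_util.register_pytree_node(UM, UM._tree_flatten, UM._tree_unflatten)

def body(i, ums):
    old = ums[0]
    # Same shape and dtype as the incoming leaf, so the carry type is stable.
    e = jnp.identity(2, dtype=old.arr.dtype).reshape(old.arr.shape)
    # Return a new list rather than mutating the traced input in place.
    return [UM(e, old.r)] + list(ums[1:])

uns = jnp.array([[1, 2, 3, 4]] * 6)
ums = [UM(un, [2]) for un in uns]
ums = jax.lax.fori_loop(0, 20, body, ums)
out = jnp.array([u.arr.flatten() for u in ums])
print(out[0].tolist())  # [1, 0, 0, 1]
```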

Hopefully this gets you on the path toward working code!

Answered By: jakevdp