How to optimize and speed up this matrix multiplication in Python

Question:

According to the gradient equation, the matrix multiplication is given by
[equation image omitted]

where both @ and * are needed. Here is the code if readers are interested:

import numpy as np
from numpy.linalg import inv

# parameters
beta     =  0.98 
alpha    =  0.03
delta    =  0.1
T        =  1000
loop     =  1
dif      =  1
tol      =  1e-8

kss      =  ((1 / beta - (1 - delta)) / alpha)**(1 / (alpha - 1))
k        =  np.linspace(0.5 * kss, 1.8 * kss, T)

k_reshaped   =  k.reshape(-1, 1)
c            =  k_reshaped ** alpha + (1 - delta) * k_reshaped - k
c[c<0]       =  1e-11
c            =  np.log(c)
beta_square  =  beta**2

# multiplication
I   =  np.identity(T)
E   =  np.ones(T)[:,None]
Q2  =  I

while np.any(dif > tol) and loop < 200:
    J   =  beta * Q2
    B   =  inv(I - J)

    Q3  =  np.zeros([T,T])
    ini =  np.argmax(c + (B @ (J * c) @ E).flatten(),axis=1)
    Q3[np.arange(T),ini]  =  1


    gB  =  2 * B @ (J * c @ E) @ (beta * Q2 * c @ E + B @ (np.linalg.matrix_power(I - J, 2) * c @ E)).T / beta_square
    B   += 0.1 * gB

    dif  =  np.max(np.absolute(Q3 - Q2))
    kcQ  =  k[ini]

    Q2   =  Q3
    loop += 1

Basically, it follows a stochastic gradient descent algorithm: matrix B is initialized by B = inv(I - J) and evolves via B += 0.1 * gB, J varies along with Q2, and Q2 has to be determined in each iteration. However, Q2 is sparse: each row contains exactly one entry equal to 1 and the rest are zero. In the code it looks like:

ini =  np.argmax(c + (B @ (J * c) @ E).flatten(),axis=1)
Q3[np.arange(T),ini]  =  1
...
Q2  =  Q3
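
For illustration only (not part of the actual code above), this is what that selection structure looks like for a tiny made-up T, using the same indexing trick:

import numpy as np

# Toy example: each row of Q gets a single 1 at the column chosen by argmax,
# every other entry stays 0 (the indices here are made up for illustration).
T_small   = 4
ini_small = np.array([2, 0, 3, 1])      # hypothetical argmax result
Q_small   = np.zeros((T_small, T_small))
Q_small[np.arange(T_small), ini_small] = 1
print(Q_small)
# [[0. 0. 1. 0.]
#  [1. 0. 0. 0.]
#  [0. 0. 0. 1.]
#  [0. 1. 0. 0.]]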

The code above works with 1000-by-1000 matrices; could this be optimized to run even faster?

Asked By: Zuba Tupaki


Answers:

Here are some improvements:

  • beta * Q2 is computed twice; J can be reused the second time.
  • J * c is also computed multiple times while it can be done once. The same goes for I - J.
  • B @ (J * c) @ E and B @ (J * c @ E) are mathematically equivalent, but the latter is faster in your case and can also be computed once (see the standalone sketch after this list).
  • CPython optimizes (almost) nothing and Numpy performs operations eagerly, so 0.1 * (2 * Matrix / beta_square) actually computes a new matrix M2 = 2 * Matrix, then another one M3 = M2 / beta_square, then yet another one M4 = 0.1 * M3. Creating many temporary matrices like this is expensive because it is a memory-bound operation and memory bandwidth is quite limited on modern machines (compared to the computing power), not to mention that writing into freshly allocated arrays is generally slower than writing into already-filled ones (because of virtual memory and more specifically page faults). Thus, it is faster to compute (0.1 * 2 / beta_square) * Matrix, since multiplying floats is much cheaper than scaling a big matrix several times.
  • Some basic operations like np.argmax(c + tmp3.flatten(), axis=1) or np.max(np.absolute(Q3 - Q2)) can be easily accelerated using Numba.
  • In-place operations are generally faster than out-of-place ones (again, because of expensive temporary arrays). You can use them thanks to the out parameter of basic functions (e.g. np.multiply(A, B, out=C)). That being said, the benefit is quite small here since inv takes a significant amount of the time.
  • Assuming B is not needed at the end of the loop, you can use np.linalg.solve instead, as mentioned by Homer512. Both inversion and solving are O(n**3) for dense matrices, but solving a single right-hand side only needs the factorization plus two cheap triangular solves, which is roughly 3x less work than forming the full inverse, and it is often more accurate. See Don't invert that matrix. For example, inv(I-J) @ b can be replaced by solve(I-J, b) (also shown in the sketch after this list). The benefit of using solve is not so big in your specific use-case though, because the I-J matrix is sparse.
  • If B is actually used at the end of the loop, then this is a bit more complex. Numba can help to write a relatively fast matrix inversion specialized for sparse matrices like the one in your use-case (the inversion provided by Scipy turns out to be pretty slow here).
  • np.linalg.matrix_power(tmp0, 2) * c can also be optimized in Numba for sparse matrices.
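
To make the points above concrete, here is a small standalone sketch. It uses random data instead of the actual c and Q2, so it is only illustrative; it checks the algebraic equivalence, the scalar folding, the out= pattern and the solve-versus-inv replacement:

import numpy as np
from numpy.linalg import inv, solve

rng  = np.random.default_rng(0)
T    = 500
beta = 0.98
c    = rng.standard_normal((T, T))
Q2   = np.zeros((T, T))
Q2[np.arange(T), rng.integers(0, T, T)] = 1.0   # one 1 per row, like Q3
I = np.identity(T)
E = np.ones((T, 1))
J = beta * Q2
B = inv(I - J)

# 1) B @ (J * c) @ E equals B @ ((J * c) @ E); the second grouping avoids
#    the expensive (T,T) @ (T,T) matrix product entirely.
print(np.allclose(B @ (J * c) @ E, B @ ((J * c) @ E)))                 # True

# 2) Folding the scalars avoids scaling a big matrix three times.
M = rng.standard_normal((T, T))
print(np.allclose(0.1 * (2 * M / beta**2), (0.1 * 2 / beta**2) * M))   # True

# 3) In-place variant of J * c: same values, no new allocation per call.
tmp = np.empty_like(M)
np.multiply(J, c, out=tmp)
print(np.allclose(tmp, J * c))                                         # True

# 4) solve(A, b) gives the same vector as inv(A) @ b without forming inv(A).
b = rng.standard_normal((T, 1))
print(np.allclose(inv(I - J) @ b, solve(I - J, b)))                    # True

Point 4 is the replacement suggested above; points 1 to 3 mostly reduce temporaries and thus memory traffic, which is where much of the time goes for element-wise work on 1000x1000 matrices.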

Here is a (barely tested) implementation using Numba:

import numpy as np
import numba as nb

@nb.njit('(float64[:,::1], float64[::1])', parallel=True)
def compute_ini(a, b):
    n, m = a.shape
    assert b.size == m and m > 0
    res = np.empty(n, np.int64)
    for i in nb.prange(n):
        max_val, max_pos = a[i, 0] + b[0], 0
        for j in range(1, m):
            val = a[i, j] + b[j]
            if val > max_val:
                max_val = val
                max_pos = j
        res[i] = max_pos
    return res

@nb.njit('(float64[:,::1], float64[:,::1])', parallel=True)
def max_abs_diff(a, b):
    return np.max(np.absolute(a - b))

# Utility function for invert_sparse_matrix
@nb.njit
def invert_sparse_matrix_subkernel(b, out, i1, n, eps):
    for i2 in range(n):
        if i2 != i1:
            scale = b[i2, i1]
            if abs(scale) >= eps:
                for j in range(n):
                    b[i2, j] -= scale * b[i1, j]
                    out[i2, j] -= scale * out[i1, j]

@nb.njit('(float64[:,::1], float64[:,::1])', parallel=True)
def invert_sparse_matrix(a, out):
    eps = 1e-14
    n, m = a.shape
    assert n == m and out.shape == (n,n)

    b = np.empty((n, n))

    for i in nb.prange(n):
        out[i, :i] = 0.0
        out[i, i] = 1.0
        out[i, i+1:] = 0.0
        b[i, :] = a[i, :]

    for i1 in range(n):
        scale = 1.0 / b[i1, i1]
        if abs(scale) < eps:
            b[i1, :].fill(0.0)
            out[i1, :].fill(0.0)
            invert_sparse_matrix_subkernel(b, out, i1, n, eps)
        elif abs(scale-1.0) < eps:
            invert_sparse_matrix_subkernel(b, out, i1, n, eps)
        else:
            b[i1, :] *= scale
            out[i1, :] *= scale
            invert_sparse_matrix_subkernel(b, out, i1, n, eps)

@nb.njit('(float64[:,::1], float64[:,::1], float64[:,::1])', parallel=True)
def sparse_square_premult(a, premult, out):
    eps = 1e-14
    n, m = a.shape
    assert n == m
    assert premult.shape == (n, n)
    assert out.shape == (n, n)

    for i in nb.prange(n):
        out[i, :] = 0.0
        for j in range(n):
            if abs(a[i, j]) >= eps:
                for k in range(n):
                    out[i, k] += a[i, j] * a[j, k]
        out[i, :] *= premult[i, :]

def compute_numba():
    # parameters
    beta     =  0.98 
    alpha    =  0.03
    delta    =  0.1
    T        =  1000
    loop     =  1
    dif      =  1
    tol      =  1e-8

    kss      =  ((1 / beta - (1 - delta)) / alpha)**(1 / (alpha - 1))
    k        =  np.linspace(0.5 * kss, 1.8 * kss, T)

    k_reshaped   =  k.reshape(-1, 1)
    c            =  k_reshaped ** alpha + (1 - delta) * k_reshaped - k
    c[c<0]       =  1e-11
    c            =  np.log(c)
    beta_square  =  beta**2

    # multiplication
    I   =  np.identity(T)
    E   =  np.ones(T)[:,None]
    Q2  =  I

    J = np.empty((T, T))
    tmp0 = np.empty((T, T))
    tmp1 = np.empty((T, T))
    tmp2 = np.empty((T, 1))
    tmp3 = np.empty((T, 1))
    tmp4 = np.empty((T, T))
    B = np.empty((T, T))

    while np.any(dif > tol) and loop < 200:
        np.multiply(beta, Q2, out=J)
        np.subtract(I, J, out=tmp0)
        invert_sparse_matrix(tmp0, B)

        Q3 = np.zeros((T,T))
        np.multiply(J, c, out=tmp1)
        np.matmul(tmp1, E, out=tmp2)
        np.matmul(B, tmp2, out=tmp3)
        ini = compute_ini(c, tmp3.flatten())
        Q3[np.arange(T), ini] = 1

        factor = 0.1 * 2 / beta_square
        sparse_square_premult(tmp0, c, tmp4)
        np.add(B, (factor * tmp3) @ (tmp2 + B @ (tmp4 @ E)).T, out=B)

        dif = max_abs_diff(Q3, Q2)
        kcQ = k[ini]

        Q2 = Q3
        loop += 1

compute_numba()

Note that the invert_sparse_matrix function still takes a significant amount of time (>50% on my machine), though it is about 3x faster than inv and about as fast as solve. It is an improvement of the naive inversion algorithm with a few optimizations for very sparse matrices. It can certainly be optimized further (e.g. using tiling), but this is not trivial to do (especially for novice programmers).
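
If you want to check those relative timings on your own machine, a rough benchmark sketch along these lines can be used. It assumes the invert_sparse_matrix function defined above has already been compiled, and the numbers will of course vary with the hardware and the BLAS implementation:

import time
import numpy as np
from numpy.linalg import inv, solve

def bench(label, fn, repeat=3):
    # Crude timing helper: report the best of `repeat` runs.
    best = float('inf')
    for _ in range(repeat):
        start = time.perf_counter()
        fn()
        best = min(best, time.perf_counter() - start)
    print(f'{label:>22s}: {best * 1e3:8.1f} ms')

T    = 1000
beta = 0.98
rng  = np.random.default_rng(0)
Q2   = np.zeros((T, T))
Q2[np.arange(T), rng.integers(0, T, T)] = 1.0
A    = np.identity(T) - beta * Q2        # same sparse structure as I - J
b    = np.ones((T, 1))
out  = np.empty((T, T))

bench('inv', lambda: inv(A))
bench('solve (one rhs)', lambda: solve(A, b))
# invert_sparse_matrix copies A internally, so A is left untouched.
bench('invert_sparse_matrix', lambda: invert_sparse_matrix(A, out))

This only measures the inversion/solve step itself, not the rest of the loop.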

Note that the compilation takes a few seconds.

Overall, this implementation is about 4~5 times faster than the initial one on my i5-9600KF processor (6 cores).

Answered By: Jérôme Richard