Is it possible to improve python performance for this code?

Question:

I have a simple piece of code that:

Reads a trajectory file, which can be seen as a list of 2D arrays (lists of positions in space), into Y

Computes the RMSD for each pair of snapshots (scipy.pdist style)

My code works fine:

import numpy as np
from ase.io import read  # assumed: ASE's reader, which provides .get_positions()

trajectory = read("test.lammpstrj", index="::")
m = len(trajectory)
# .get_positions() returns a 2D NumPy array
Y = np.array([snapshot.get_positions() for snapshot in trajectory])

b = [np.sqrt(((Y[i] - Y[j]) ** 2 * 3).mean()) for i in range(m) for j in range(i + 1, m)]
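Assuming get_positions() returns an (N, 3) array of Cartesian coordinates for N atoms (as ASE's Atoms.get_positions() does), the factor 3 in the list comprehension turns the mean over all 3N coordinates into the mean squared displacement per atom. No rotational or translational superposition is applied, so this is the RMSD of the raw coordinates:

```latex
\mathrm{RMSD}(Y_i, Y_j)
  = \sqrt{\frac{1}{N}\sum_{a=1}^{N}\lVert y_{i,a} - y_{j,a}\rVert^2}
  = \sqrt{3\cdot\frac{1}{3N}\sum_{a=1}^{N}\sum_{c=1}^{3}\left(y_{i,a,c} - y_{j,a,c}\right)^2}
  = \sqrt{3\,\operatorname{mean}\!\left[(Y_i - Y_j)^2\right]}
```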

This code executes in 0.86 seconds with Python 3.10; similar code in Julia 1.8 executes in 0.46 seconds.

I plan to work with much larger trajectories (~200,000 elements). Would it be possible to get a speed-up in Python, or should I stick to Julia?

Asked By: Okano


Answers:

Stick to Julia.

If you have already written it in a language that runs faster, why are you trying to use Python in the first place?

Answered By: Craze XD

You’ve mentioned that snapshot.get_positions() returns a 2D array; suppose it has shape (p, q). Then Y is a 3D array of shape (m, p, q), where m is the number of snapshots in the trajectory, and you expect m to become quite large.

Let’s look at a basic way to speed up the distance calculation, with m = 1000:

import numpy as np

# dummy inputs
m = 1000
p, q = 4, 5
Y = np.random.randn(m, p, q)

# your current method
def foo():
    return [np.sqrt(((((Y[i]- Y[j])**2))*3).mean()) for i in range(m) for j in range(i + 1, m)]

# vectorized approach -> compute the upper triangle of the pairwise distance matrix
def bar():
    u, v = np.triu_indices(Y.shape[0], 1)
    return np.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))

# Check for correctness

out_1 = foo()
out_2 = bar()
print(np.allclose(out_1, out_2))
# True
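bar() returns values in the same order as the nested loop in foo() because np.triu_indices(m, 1) lists the upper-triangle indices in row-major order (i ascending, then j ascending), which is also the condensed order used by scipy.spatial.distance.pdist. A small check:

```python
import numpy as np

m = 4
u, v = np.triu_indices(m, 1)  # row indices u and column indices v, with u < v

# Pairs visited by the nested loop in foo()
loop_pairs = [(i, j) for i in range(m) for j in range(i + 1, m)]

print(list(zip(u.tolist(), v.tolist())))
# [(0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3)]
print(list(zip(u.tolist(), v.tolist())) == loop_pairs)
# True
```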

If we test the time required:

%timeit -n 10 -r 3 foo()
# 3.16 s ± 50.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

The first method is really slow: it takes over 3 seconds for this calculation. Let’s check the second method:

%timeit -n 10 -r 3 bar()
# 97.5 ms ± 405 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

So we have a ~30x speedup here, which would make your large calculation in Python much more feasible than with the original code. Feel free to test other sizes of Y to see how it scales compared to the original.
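One caveat with bar(): indexing Y[u] and Y[v] materialises arrays of shape (m(m-1)/2, p, q), so temporary memory grows quadratically with m. With m = 1000 and (p, q) = (4, 5), each of them already holds 499,500 x 20 float64 values (about 80 MB). A middle ground, sketched below under the assumption that Y is an (m, p, q) NumPy array, vectorises only the inner loop so that each temporary has at most (m - 1, p, q) elements. The function name rowwise_rmsd is ours, and its speed has not been measured here.

```python
import numpy as np

def rowwise_rmsd(Y):
    """Condensed pairwise RMSD, vectorised one row at a time to bound memory."""
    m = Y.shape[0]
    res = np.empty(m * (m - 1) // 2)
    start = 0
    for i in range(m - 1):
        # Differences between snapshot i and all later snapshots: shape (m - i - 1, p, q)
        diff = Y[i + 1:] - Y[i]
        row = np.sqrt(3 * (diff ** 2).mean(axis=(1, 2)))
        # Row i occupies a contiguous slice of the condensed output
        res[start:start + row.size] = row
        start += row.size
    return res

# Usage with the same kind of dummy input as above
rng = np.random.default_rng(0)
Y = rng.standard_normal((200, 4, 5))
out = rowwise_rmsd(Y)
```

The output has the same ordering as bar(), so the two can be compared directly with np.allclose.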


JIT

In addition, you can also try JIT compilation, mainly with JAX or Numba. It is fairly simple to port the function bar to jax.numpy, for example:

import jax
import jax.numpy as jnp

@jax.jit
def jit_bar(Y):
    u, v = jnp.triu_indices(Y.shape[0], 1)
    return jnp.sqrt((3 * (Y[u] - Y[v]) ** 2).mean(axis=(-1, -2)))

# check for correctness

print(np.allclose(bar(), jit_bar(Y)))
# True

If we test the time of the jitted jnp op:

%timeit -n 10 -r 3 jit_bar(Y)
# 10.6 ms ± 678 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)

So compared to the original, we could reach up to a ~300x speedup.

Note that not every operation can be converted to JAX/JIT this easily (this particular problem happens to be well suited), so the general advice is simply to avoid Python loops and use NumPy’s broadcasting/vectorization capabilities, as in bar().
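Since the question mentions pdist: flattening each snapshot into a vector of length n = p*q turns this RMSD into a scaled Euclidean distance, sqrt(3/n) * ||Y_i - Y_j||, so SciPy's scipy.spatial.distance.pdist can run the whole pairwise loop in compiled code and return the condensed vector directly. A sketch, assuming SciPy is installed and Y is an (m, p, q) float array; pdist_rmsd is a name chosen here, and timings were not measured:

```python
import numpy as np
from scipy.spatial.distance import pdist

def pdist_rmsd(Y):
    m = Y.shape[0]
    X = Y.reshape(m, -1)  # each snapshot becomes one row of length n = p * q
    n = X.shape[1]
    # pdist returns Euclidean distances in condensed (row-major upper-triangle) order;
    # sqrt(3 * sum / n) == sqrt(3 / n) * sqrt(sum)
    return np.sqrt(3.0 / n) * pdist(X, metric="euclidean")

rng = np.random.default_rng(1)
Y = rng.standard_normal((300, 4, 5))
out = pdist_rmsd(Y)
```

The result uses the same ordering as bar(), and scipy.spatial.distance.squareform can expand it to a full m x m matrix if needed.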

Answered By: Mercury

Your question is about speeding up Python, relative to Julia, so I’d like to offer some Julia code for comparison.

Since your data is most naturally expressed as a list of 4×5 arrays, I suggest representing it as a vector of SMatrix values:

sumdiff2(A, B) = sum((A[i] - B[i])^2 for i in eachindex(A, B))
function dists(Y)
    M = length(Y)
    V = Vector{float(eltype(eltype(Y)))}(undef, sum(1:M-1))
    Threads.@threads for i in eachindex(Y)
        ii = sum(M-i+1:M-1)  # don't worry about this sum
        for j in i+1:lastindex(Y)
            ind = ii + (j-i)
            V[ind] = sqrt(3 * sumdiff2(Y[i], Y[j])/length(Y[i]))
        end
    end
    return V
end

using Random: randn
using StaticArrays: SMatrix
Ys = [randn(SMatrix{4,5,Float64}) for _ in 1:1000];
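The ii offset in dists counts how many pairs precede row i in the condensed vector: with 1-based indices, rows 1..i-1 hold (M-1) + (M-2) + ... + (M-i+1) pairs, which is sum(M-i+1:M-1). With 0-based indices, the same count has the closed form i(2m - i - 1)/2, which is the offset used in compute_line in the Numba answer. A quick Python check of that equivalence (the helper names are hypothetical, for illustration only):

```python
def offset_by_sum(i, m):
    # 0-based: pairs in rows 0..i-1, where row k holds m - 1 - k pairs
    return sum(m - 1 - k for k in range(i))

def offset_closed_form(i, m):
    # Closed form used as `offset` in compute_line
    return i * (2 * m - i - 1) // 2

def julia_ii(i1, M):
    # DNF's 1-based sum(M-i+1:M-1); Julia ranges include the end point
    return sum(range(M - i1 + 1, M))

m = 7
print([offset_closed_form(i, m) for i in range(m)])
# [0, 6, 11, 15, 18, 20, 21]
```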

Benchmarks:

# single-threaded
julia> using BenchmarkTools
julia> @btime dists($Ys);
  6.561 ms (2 allocations: 3.81 MiB)

# multi-threaded with 6 cores
julia> @btime dists($Ys);
  1.606 ms (75 allocations: 3.82 MiB)

I was not able to install JAX on my computer, but comparing with @Mercury’s NumPy code I got

foo: 5.5 s
bar: 179 ms

i.e. the multithreaded dists is approximately 3400x faster than foo.

It is possible to write this as a one-liner at a ~2-3x performance cost.

Answered By: DNF

While Python tends to be slower than Julia for many tasks, it is possible to write numerical code in Python that is as fast as Julia by using Numba and plain loops. Indeed, Numba is built on llvmlite, a lightweight binding to the LLVM toolchain that Numba uses for JIT compilation. The standard implementation of Julia also uses a JIT based on LLVM. This means the two should perform similarly, apart from language-specific overheads, which become negligible once the computation runs in parallel (because the resulting computation will be memory-bound on nearly all modern platforms).

This computation can be parallelized in both Julia and Python (still using Numba). While writing a sequential version is quite straightforward, writing a parallel one is a bit more complex. Indeed, computing the upper-triangular values row by row can result in an imbalanced workload and thus a sub-optimal execution time. An efficient strategy is to compute a pair of rows in each iteration: one from the top of the upper-triangular part and one from the bottom. With 0-based indices, the top row i contains m-1-i items while the bottom row m-1-i contains i items, so every iteration computes m-1 items regardless of i. This results in much better load balancing. When m is odd, the middle row has to be computed separately.
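The pairing can be checked without Numba. The pure-Python sketch below (the helper names are hypothetical) lists the rows handled by each parallel iteration, 0-based rows i and m-1-i plus the middle row m//2 when m is odd, and confirms that every iteration gets the same number of pairs and every row is visited exactly once:

```python
def paired_rows(m):
    """Rows processed per parallel iteration in the top/bottom pairing scheme."""
    iterations = [(i, m - 1 - i) for i in range(m // 2)]
    middle = m // 2 if m % 2 == 1 else None  # left over when m is odd
    return iterations, middle

def row_length(i, m):
    # Row i of the upper triangle (0-based) holds the pairs (i, j) with j > i
    return m - 1 - i

m = 9
iterations, middle = paired_rows(m)
print([row_length(a, m) + row_length(b, m) for a, b in iterations])
# [8, 8, 8, 8]
print(middle, row_length(middle, m))
# 4 4
```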

Here is the final implementation:

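import numba as nb
import numpy as np

@nb.njit(inline='always', fastmath=True)
def compute_line(tmp, res, i, m):
    # Fill row i of the condensed result: distances between snapshot i and all j > i
    n = tmp.shape[1]
    offset = (i * (2 * m - i - 1)) // 2  # number of pairs stored before row i
    factor = 3.0 / n
    for j in range(i + 1, m):
        s = 0.0
        for k in range(n):
            s += (tmp[i, k] - tmp[j, k]) ** 2
        res[offset] = np.sqrt(s * factor)
        offset += 1

@nb.njit(parallel=True, fastmath=True)
def fastest(Y):
    m, n = Y.shape[0], Y.shape[1] * Y.shape[2]
    res = np.empty(m * (m - 1) // 2)
    tmp = Y.reshape((m, n))  # flatten each snapshot into a row of length n
    # Pair a top row with a bottom row so each iteration does the same amount of work
    for i in nb.prange(m // 2):
        compute_line(tmp, res, i, m)
        compute_line(tmp, res, m - i - 1, m)
    if m % 2 == 1:
        # The middle row is left over when m is odd
        compute_line(tmp, res, m // 2, m)
    return res

# [...] same setup as in the other answers (Y of shape (m, p, q))
%timeit -n 100 fastest(Y)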

Results

Here are the performance results on my machine (an i5-9600KF with 6 cores):

foo     (seq, Python, Mercury):  4910.7 ms
bar     (seq, Python, Mercury):   134.2 ms
jit_bar (seq, Python, Mercury):     ???
dists   (seq, Julia,  DNF):         6.9 ms
dists   (par, Julia,  DNF):         2.2 ms
fastest (par, Python, me):          1.5 ms  <-----

(JAX does not work on my machine, so I cannot test it yet.)

This implementation is the fastest one and succeeds in beating the best Julia code so far.


Optimal implementation

Note that for large arrays like (200_000, 4, 5), all implementations provided so far are inefficient since they are not cache-friendly. Indeed, the input array takes about 32 MB and will not fit in the cache of most modern processors (and even if it could, one needs to consider the space needed for the output and the fact that caches are not perfect). This can be fixed using tiling, at the expense of even more complex code. I think such an implementation should be optimal if you use Z-order curves.
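To put numbers on the sizes involved for m = 200,000 snapshots of shape (4, 5) in float64, a back-of-the-envelope calculation (the sizes come from the question and answer, the code is only illustrative):

```python
m, p, q = 200_000, 4, 5
itemsize = 8  # bytes per float64

input_bytes = m * p * q * itemsize
n_pairs = m * (m - 1) // 2
output_bytes = n_pairs * itemsize

print(input_bytes)   # 32000000      (about 32 MB, roughly 30.5 MiB)
print(n_pairs)       # 19999900000
print(output_bytes)  # 159999200000  (about 160 GB for the condensed result alone)
```

At that size the full condensed result is far larger than the input, so it would likely have to be consumed in chunks (or reduced on the fly) rather than stored in memory.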

Answered By: Jérôme Richard