Sum of List Comprehension Not Running Fast Enough

Question:

I have three lists, and I am computing the sum of a list comprehension over them. With lists of length n >= 1500, I have been unable to make my code run any faster than ~3 s per list comprehension. This code needs to run thousands of times, so 3 s per run does not cut it.

Below is what my current attempt looks like. The split is just a float determined earlier in my code.

sum([list1[k] * (list2[k] == 1) if list3[k] < split else list1[k] * (list2[k] == -1) for k in range(n)])

list1 contains 1500 positive floats between 0 and 1 which sum to 1.

list2 contains 1500 randomly sampled -1’s and 1’s.

list3 contains 1500 values randomly sampled from a normal distribution; for example, np.random.normal(5, 0.5, 3) draws three such values.
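To make the setup reproducible, here is a minimal sketch that builds lists matching the description above and runs the expression from the question unchanged. The threshold split = 5.0, the seed, and the use of NumPy's default_rng are assumptions for illustration; the question does not give them.

```python
import numpy as np

n = 1500
split = 5.0  # hypothetical threshold; the question only says it is a float

rng = np.random.default_rng(0)
weights = rng.random(n)
list1 = (weights / weights.sum()).tolist()     # floats in [0, 1) that sum to 1
list2 = rng.choice([-1, 1], size=n).tolist()   # random -1s and 1s
list3 = rng.normal(5, 0.5, n).tolist()         # normal samples centred on 5

# The expression from the question, unchanged.
result = sum([list1[k] * (list2[k] == 1) if list3[k] < split else list1[k] * (list2[k] == -1) for k in range(n)])
print(result)
```

Since list1 sums to 1 and the expression keeps only a subset of its entries, result always lands between 0 and 1.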

Asked By: namor129


Answers:

I ended up writing three approaches to your question: improved pure Python, NumPy, and Numba.

  • The improved Python version, based on @KellyBelly’s comment, works nicely. Iterating over the three lists with zip, instead of indexing each of them by k, has a surprisingly strong effect on performance here.

  • With NumPy, you want to leverage vectorized operations: turn your conditions into boolean masks and get rid of Python loops entirely.

  • Numba is usually the fastest solution if you are comfortable with its key concepts (njit, prange, etc.). It takes a bit more proofreading than the NumPy approach, but the effort is well rewarded.

Note that these are only different ways of implementing the same algorithm. Improving an inefficient algorithm matters just as much if you are chasing those precious milliseconds.
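The mask idea from the NumPy bullet can be taken one step further: np.where chooses, element by element, which of the two comparisons applies, and a dot product with the resulting 0/1 selector sums the chosen weights. This is a sketch of that variant, not the answer's timed code; the function name masked_sum and the small input arrays are made up for illustration.

```python
import numpy as np

def masked_sum(a1, a2, a3, split):
    """Sum a1[k] where a2[k] == 1 if a3[k] < split, else where a2[k] == -1."""
    a1, a2, a3 = np.asarray(a1), np.asarray(a2), np.asarray(a3)
    # One boolean per element; which comparison applies depends on a3[k] < split.
    selected = np.where(a3 < split, a2 == 1, a2 == -1)
    # Dot product of the weights with the 0/1 selector sums the selected weights.
    return a1 @ selected.astype(a1.dtype)

a1 = [0.1, 0.2, 0.3, 0.4]
a2 = [1, -1, 1, -1]
a3 = [4.0, 4.5, 5.5, 6.0]
print(masked_sum(a1, a2, a3, split=5.0))
```

Here only the first element (below the split, a2 == 1) and the last one (above the split, a2 == -1) are kept.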
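As one example of an algorithmic change, under an assumption the question does not confirm: if the three lists stay fixed and only split changes between the thousands of runs, each run does not need to touch all n elements. Sort by list3 once, precompute cumulative sums of the weights that count below the split (list2 == 1) and at or above it (list2 == -1), and answer each split with a binary search via np.searchsorted. The helper build_query is hypothetical.

```python
import numpy as np

def build_query(a1, a2, a3):
    """Hypothetical helper: precompute once, then answer any split in O(log n)."""
    a1 = np.asarray(a1, dtype=float)
    a2 = np.asarray(a2)
    a3 = np.asarray(a3, dtype=float)
    order = np.argsort(a3, kind="stable")
    sorted_a3 = a3[order]
    # Weights that count below the split (a2 == 1) and at or above it (a2 == -1).
    pos = np.where(a2[order] == 1, a1[order], 0.0)
    neg = np.where(a2[order] == -1, a1[order], 0.0)
    # Prefix sums with a leading 0: prefix[i] is the total of the first i sorted items.
    pos_prefix = np.concatenate(([0.0], np.cumsum(pos)))
    neg_prefix = np.concatenate(([0.0], np.cumsum(neg)))
    neg_total = neg_prefix[-1]

    def query(split):
        # i = number of items with a3 < split (strict, matching the original test).
        i = np.searchsorted(sorted_a3, split, side="left")
        return float(pos_prefix[i] + (neg_total - neg_prefix[i]))

    return query

# Example data shaped like the question's (assumed, not the asker's real data).
rng = np.random.default_rng(1)
n = 1500
w = rng.random(n)
list1 = (w / w.sum()).tolist()
list2 = rng.choice([-1, 1], size=n).tolist()
list3 = rng.normal(5, 0.5, n).tolist()

query = build_query(list1, list2, list3)
for split in (4.0, 5.0, 5.3):
    print(split, query(split))
```

The precomputation costs one sort, O(n log n); after that, each new split costs a single binary search instead of a full pass over the data. If list1, list2 or list3 also change between runs, this gains nothing and one of the implementations in this answer is the better fit.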

Timings:

Items   List comprehension   Zipped iterator   NumPy arrays   numba.njit   numba.njit(parallel=True)
1 k     0.191 ms             0.129 ms          0.487 ms       0.006 ms     0.013 ms
10 k    2.288 ms             1.206 ms          0.477 ms       0.048 ms     0.019 ms
100 k   18.941 ms            13.245 ms         2.857 ms       0.477 ms     0.056 ms
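Each number in the table comes from a single call wrapped in the time_this helper defined in the code below. A single measurement is easily skewed by caching or background load, so a common, more robust pattern is timeit.repeat, keeping the fastest trial. This is a sketch; work is a hypothetical stand-in for whichever function you want to measure.

```python
import timeit

def work():
    # Hypothetical stand-in workload; replace with the function being measured.
    return sum(x * x for x in range(10_000))

# Call `work` `number` times per trial, run 5 trials, and keep the fastest one.
number = 100
trials = timeit.repeat(work, number=number, repeat=5)
best_ms = min(trials) / number * 1000
print(f"{best_ms:.3f} ms per call")
```

The minimum is preferred over the mean because noise from other processes only ever makes a run slower, never faster.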

Code:

# Imports.
import numba as nb
import numpy as np
np.random.seed(0)

# Data.
N = 100000
SPLIT = 50
array1 = np.random.randint(0, 100, N)
array2 = np.random.choice((+1, -1), N)
array3 = np.random.randint(0, 100, N)
list1, list2, list3 = map(lambda a: a.tolist(), (array1, array2, array3))
print(N)

# Helpful timing function.
from contextlib import contextmanager
import time

@contextmanager
def time_this():
    t0 = time.perf_counter()
    yield
    dt = time.perf_counter() - t0
    print(f"{dt*1000:.3f} ms")

# List comprehension.
def list_comprehension():
    n = len(list1)
    return sum([list1[k] * (list2[k] == 1) if list3[k] < SPLIT else list1[k] * (list2[k] == -1) for k in range(n)])

# Zipped iterator.
def zipped_iterator():
    return sum(l1 if l2 == (1 if l3 < SPLIT else -1) else 0 for l1, l2, l3 in zip(list1, list2, list3))

# Numpy array.
def numpy_arrays():
    mask = array3 < SPLIT
    positives = array1[mask] * (array2[mask] == 1)
    negatives = array1[~mask] * (array2[~mask] == -1)
    return positives.sum() + negatives.sum()

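# Numba. The arrays and the threshold are passed as arguments because Numba
# treats global variables as compile-time constants: globals would be frozen
# into the compiled function, and later changes to the data would be ignored.
@nb.njit
def numba_count(a1, a2, a3, split):
    total = 0
    n = len(a1)
    for k in range(n):  # Plain range: prange only parallelizes with parallel=True.
        if a3[k] < split:
            sign = +1
        else:
            sign = -1
        if a2[k] == sign:
            total += a1[k]
    return total

# Numba in parallel. prange splits the loop across threads, and Numba
# recognizes `total += ...` as a reduction.
@nb.njit(parallel=True)
def numba_count2(a1, a2, a3, split):
    total = 0
    n = len(a1)
    for k in nb.prange(n):
        if a3[k] < split:
            sign = +1
        else:
            sign = -1
        if a2[k] == sign:
            total += a1[k]
    return total

# Timings.
totals = []
with time_this():
    totals.append(list_comprehension())

with time_this():
    totals.append(zipped_iterator())

with time_this():
    totals.append(numpy_arrays())

numba_count(array1, array2, array3, SPLIT) # Compile before we time anything.
with time_this():
    totals.append(numba_count(array1, array2, array3, SPLIT))

numba_count2(array1, array2, array3, SPLIT) # Compile before we time anything.
with time_this():
    totals.append(numba_count2(array1, array2, array3, SPLIT))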

# Assert that all the returned values are identical.
assert np.isclose(totals, totals[0]).all()
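Because list2 holds nothing but -1 and 1, the two comparisons can also be folded into arithmetic. With sign_k = +1 when list3[k] < split and -1 otherwise, the indicator [list2[k] == sign_k] equals (1 + list2[k] * sign_k) / 2, so the whole sum becomes (sum(list1) + dot(list1, list2 * sign)) / 2: one dot product and no boolean indexing. This sketch is an extra alternative, not one of the timed functions above, and it silently gives wrong results if list2 ever contains a value other than -1 or 1. The name signed_dot_sum is hypothetical.

```python
import numpy as np

def signed_dot_sum(a1, a2, a3, split):
    """Sum a1[k] where a2[k] == (1 if a3[k] < split else -1).

    Assumes every a2[k] is exactly -1 or 1; otherwise the identity below fails.
    """
    a1 = np.asarray(a1, dtype=float)
    a2 = np.asarray(a2)
    sign = np.where(np.asarray(a3) < split, 1, -1)
    # (1 + a2*sign) / 2 is 1 when a2 == sign and 0 when they differ (both are +/-1).
    return (a1.sum() + a1 @ (a2 * sign)) / 2

# Check against the zipped-iterator formula on data shaped like the benchmark's.
rng = np.random.default_rng(0)
n = 1000
a1 = rng.integers(0, 100, n)
a2 = rng.choice([1, -1], n)
a3 = rng.integers(0, 100, n)
split = 50
reference = sum(x if y == (1 if z < split else -1) else 0 for x, y, z in zip(a1, a2, a3))
print(signed_dot_sum(a1, a2, a3, split), reference)
```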
Answered By: Guimoute