How do I add a skip condition to a for loop without slowing it down?
Question:
I am writing a function which processes a big list of elements, and want to add the functionality that certain elements are skipped depending on an input argument. Here’s what the code looks like:
def f(lst, skipsEnabled=False):
    for elem in lst:
        if skipsEnabled and skipCheck(elem):
            continue
        else:
            pass  # do stuff to elem here
When skipsEnabled is False, the if skipsEnabled conditional could be ignored every time; however, it still gets checked on every iteration, and it slows my loop down. I could put the conditional before the loop, but then I’d need to duplicate the loop, which is not very OOPy since all the elem processing code would be copy-pasted into both loops.
Is there a clean way to do this?
Answers:
Move the element-processing code into its own function, then write the loop twice, once with the conditional and once without, and call that same function from both loops.
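A minimal sketch of this suggestion, assuming a hypothetical process() that stands in for the real per-element work and a sample skipCheck() that skips odd numbers. Only the loop header is duplicated; the processing code lives in one place:

```python
def skipCheck(elem):
    # Hypothetical skip rule for illustration: skip odd numbers.
    return elem % 2 == 1

def process(elem, out):
    # Hypothetical stand-in for "do stuff to elem here".
    out.append(elem * 10)

def f(lst, skipsEnabled=False):
    out = []
    if skipsEnabled:
        # Only this loop pays for the skip check.
        for elem in lst:
            if not skipCheck(elem):
                process(elem, out)
    else:
        # No per-iteration check at all in this loop.
        for elem in lst:
            process(elem, out)
    return out

print(f([1, 2, 3, 4]))        # [10, 20, 30, 40]
print(f([1, 2, 3, 4], True))  # [20, 40]
```

The price of this approach is an extra function call per element, which the benchmark further down measures.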
It is highly unlikely that this gives you a large efficiency saving, though. When skipsEnabled is False, skipsEnabled and skipCheck(elem) short-circuits and never calls skipCheck(), so each iteration costs only a single boolean check that always branches the same way. That will optimise well.
You could use filterfalse() from itertools to drop the elements for which skipCheck() returns a true value, and loop over the rest:
from itertools import filterfalse
def f(lst, skipsEnabled=False):
    if skipsEnabled:
        lst = filterfalse(skipCheck, lst)
    for elem in lst:
        pass  # do stuff to elem here
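For reference, filterfalse(pred, iterable) lazily yields the items for which pred(item) is false, so no intermediate list is built. A small runnable sketch, assuming a sample skipCheck() that flags odd numbers and a placeholder "multiply by 10" as the processing step (both for illustration only):

```python
from itertools import filterfalse

def skipCheck(n):
    # Sample rule (hypothetical): a truthy result means "skip this element".
    return n % 2

def f(lst, skipsEnabled=False):
    if skipsEnabled:
        # Lazily drop every element for which skipCheck() is truthy.
        lst = filterfalse(skipCheck, lst)
    processed = []
    for elem in lst:
        # Single copy of the processing code, shared by both modes.
        processed.append(elem * 10)
    return processed

print(f(range(6)))        # [0, 10, 20, 30, 40, 50]
print(f(range(6), True))  # [0, 20, 40]
```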
Benchmark with 100,000 elements, every second element being skipped when skipping is enabled:
skipsEnabled = True
9.82 ± 0.29 ms Kelly
14.50 ± 0.18 ms original
19.48 ± 0.72 ms B_Remmelzwaal
19.62 ± 0.41 ms Tim
skipsEnabled = False
1.52 ± 0.03 ms Kelly
1.54 ± 0.02 ms B_Remmelzwaal
2.16 ± 0.02 ms original
8.59 ± 0.29 ms Tim
The fastest is a variation of B Remmelzwaal’s answer, but using itertools.filterfalse:
from itertools import filterfalse
def f(lst, skipsEnabled=False):
    if skipsEnabled:
        lst = filterfalse(skipCheck, lst)
    for elem in lst:
        elem  # placeholder for the per-element work
Tim’s version is my guess at how they meant their suggestion, as they provided no code.
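The B_Remmelzwaal and Kelly variants select exactly the same elements. The likely reason for the timing gap is that filter() with a lambda makes two Python-level calls per element (the lambda, then skipCheck()), whereas filterfalse() calls skipCheck() directly. A quick equivalence check, using the same skipCheck() as the benchmark:

```python
from itertools import filterfalse

def skipCheck(n):
    # Same rule as in the benchmark: skip odd numbers.
    return n % 2

data = list(range(10))

# B_Remmelzwaal-style: the lambda wraps skipCheck, adding a call per element.
via_filter = list(filter(lambda x: not skipCheck(x), data))

# Kelly-style: filterfalse calls skipCheck itself, once per element.
via_filterfalse = list(filterfalse(skipCheck, data))

print(via_filter)                     # [0, 2, 4, 6, 8]
print(via_filter == via_filterfalse)  # True
```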
Benchmark code (Attempt This Online!):
from timeit import timeit
from itertools import filterfalse
from statistics import mean, stdev
def original(lst, skipsEnabled=False):
    for elem in lst:
        if skipsEnabled and skipCheck(elem):
            continue
        else:
            elem
def Tim(lst, skipsEnabled=False):
    def process(elem):
        elem
    if skipsEnabled:
        for elem in lst:
            if not skipCheck(elem):
                process(elem)
    else:
        for elem in lst:
            process(elem)
def B_Remmelzwaal(lst, skipsEnabled=False):
    if skipsEnabled:
        lst = filter(lambda x: not skipCheck(x), lst)
    for elem in lst:
        elem
def Kelly(lst, skipsEnabled=False):
    if skipsEnabled:
        lst = filterfalse(skipCheck, lst)
    for elem in lst:
        elem
funcs = original, Tim, B_Remmelzwaal, Kelly
def skipCheck(n):
    return n % 2  # truthy for odd n, so every second element is skipped
lst = list(range(10**5))
for skipsEnabled in True, False:
    print(f'{skipsEnabled = }')
    times = {f: [] for f in funcs}
    def stats(f):
        # mean ± stdev of the 5 fastest runs, in milliseconds
        ts = [t * 1e3 for t in sorted(times[f])[:5]]
        return f'{mean(ts):6.2f} ± {stdev(ts):4.2f} ms '
    # 25 rounds; each round times every function once
    for _ in range(25):
        for f in funcs:
            t = timeit(lambda: f(lst, skipsEnabled), number=1)
            times[f].append(t)
    for f in sorted(funcs, key=stats):
        print(stats(f), f.__name__)
    print()