Efficiently count zero elements in numpy array?

Question:

I need to count the number of zero elements in numpy arrays. I’m aware of the numpy.count_nonzero function, but there appears to be no analog for counting zero elements.

My arrays are not very large (typically fewer than 1E5 elements), but the operation is performed several million times.

Of course I could use len(arr) - np.count_nonzero(arr), but I wonder if there’s a more efficient way to do it.

Here’s a MWE of how I do it currently:

import numpy as np
import timeit

arrs = []
for _ in range(1000):
    arrs.append(np.random.randint(-5, 5, 10000))


def func1():
    for arr in arrs:
        zero_els = len(arr) - np.count_nonzero(arr)


print(timeit.timeit(func1, number=10))
Asked By: Gabriel


Answers:

A faster approach (roughly 1.6x in the timings below) is to use np.count_nonzero() on a boolean condition instead of on the array itself:

In [3]: arr
Out[3]: 
array([[1, 2, 0, 3],
       [3, 9, 0, 4]])

In [4]: np.count_nonzero(arr==0)
Out[4]: 2

In [5]: def func_cnt():
   ...:     for arr in arrs:
   ...:         # arr == 0 is a boolean mask; count_nonzero counts its True entries,
   ...:         # which is the number of zeros in arr
   ...:         zero_els = np.count_nonzero(arr == 0)
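
As a quick sanity check that the mask-based count agrees with the "total minus non-zero" approach from the question, here is a minimal sketch. The helper names count_zeros_mask and count_zeros_complement are made up for illustration, and the second one uses arr.size rather than len(arr) so that it also works for multi-dimensional arrays like the one above.

```python
import numpy as np


def count_zeros_mask(arr):
    """Count zeros by counting the True entries of the mask arr == 0."""
    return int(np.count_nonzero(arr == 0))


def count_zeros_complement(arr):
    """Count zeros as total number of elements minus non-zero elements."""
    # arr.size is the total element count; len(arr) would only give the
    # length of the first axis, which is wrong for 2-D (or higher) arrays.
    return int(arr.size - np.count_nonzero(arr))


arr = np.array([[1, 2, 0, 3],
                [3, 9, 0, 4]])
print(count_zeros_mask(arr), count_zeros_complement(arr))
```

For 1-D arrays like the ones in the question, len(arr) and arr.size are the same, so the original expression is fine there.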

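You can also use np.where(), but it’s slower than np.count_nonzero(). It returns a tuple with one index array per dimension, so the number of zeros is the length of any one of those arrays, not the length of the tuple:

In [6]: np.where(arr == 0)
Out[6]: (array([0, 1]), array([2, 2]))

In [7]: len(np.where(arr == 0)[0])
Out[7]: 2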
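
The timings further down mention func_where, which is not shown in this answer. A plausible definition, assuming it mirrors func_cnt but counts via np.where(), might look like the sketch below. Taking the list of arrays as a parameter and returning the counts are my own choices here, not something the original does.

```python
import numpy as np


def func_where(arrs):
    """Count zeros in each array via np.where (hypothetical reconstruction)."""
    counts = []
    for arr in arrs:
        # np.where returns one index array per dimension; each of them has
        # one entry per matching element, so the first one's length is the count.
        counts.append(len(np.where(arr == 0)[0]))
    return counts


arrs = [np.random.randint(-5, 5, 10000) for _ in range(1000)]
zero_counts = func_where(arrs)
```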

Efficiency (fastest first):

In [8]: %timeit func_cnt()
10 loops, best of 3: 29.2 ms per loop

In [9]: %timeit func1()
10 loops, best of 3: 46.5 ms per loop

In [10]: %timeit func_where()
10 loops, best of 3: 61.2 ms per loop
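
%timeit only works inside IPython. To reproduce the comparison as a plain script, something along these lines should do. It is only a sketch: the counters dict, run_all and bench are names I made up, and the setup is smaller than the question's so it finishes quickly. Absolute numbers will depend on your machine and NumPy version.

```python
import timeit

import numpy as np

# Smaller than the question's setup (100 arrays instead of 1000).
arrs = [np.random.randint(-5, 5, 10000) for _ in range(100)]

# Each counter maps one array to its number of zero elements.
counters = {
    "count_nonzero(arr == 0)": lambda a: np.count_nonzero(a == 0),
    "size - count_nonzero(arr)": lambda a: a.size - np.count_nonzero(a),
    "len(where(arr == 0)[0])": lambda a: len(np.where(a == 0)[0]),
}


def run_all(counter):
    """Apply one counter to every array and return the list of counts."""
    return [int(counter(a)) for a in arrs]


def bench(repeats=5):
    """Return total seconds per counter for `repeats` passes over arrs."""
    return {
        name: timeit.timeit(lambda c=c: run_all(c), number=repeats)
        for name, c in counters.items()
    }


if __name__ == "__main__":
    for name, seconds in bench().items():
        print(f"{name}: {seconds:.4f} s")
```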

More speedups with accelerators

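If you have access to accelerators (GPU/TPU), JAX can speed this up further, and the NumPy code needs very little modification to become JAX-compatible. The timing below is more than 3 orders of magnitude lower than the NumPy timings above, but treat it with caution: the jitted func_cnt returns nothing, so the compiler is free to discard the counting work, and the figure is not a like-for-like comparison. Below is the example: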

In [1]: import numpy as np
In [2]: import jax.numpy as jnp
In [3]: from jax import jit

# set up inputs
In [4]: arrs = []
In [5]: for _ in range(1000):
   ...:     arrs.append(np.random.randint(-5, 5, 10000))

# JIT'd function that performs the counting task
In [6]: @jit
   ...: def func_cnt():
   ...:     for arr in arrs:
   ...:         zero_els = jnp.count_nonzero(arr==0)

# efficiency test
In [8]: %timeit func_cnt()
15.6 µs ± 391 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
Answered By: kmario23