zip iterators asserting for equal length in python

Question:

I am looking for a nice way to zip several iterables raising an exception if the lengths of the iterables are not equal.

In the case where the iterables are lists or have a len method this solution is clean and easy:

def zip_equal(it1, it2):
    if len(it1) != len(it2):
        raise ValueError("Lengths of iterables are different")
    return zip(it1, it2)

However, if it1 and it2 are generators, the previous function fails because generators have no length: TypeError: object of type 'generator' has no len().

I imagine the itertools module offers a simple way to implement that, but so far I have not been able to find it. I have come up with this home-made solution:

def zip_equal(it1, it2):
    exhausted = False
    while True:
        try:
            el1 = next(it1)
            if exhausted: # in a previous iteration it2 was exhausted but it1 still has elements
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            exhausted = True
            # it2 must be exhausted too.
        try:
            el2 = next(it2)
            # here it2 is not exhausted.
            if exhausted:  # it1 was exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
        except StopIteration:
            # here it2 is exhausted
            if not exhausted:
                # but it1 was not exhausted => raise
                raise ValueError("it1 and it2 have different lengths")
            exhausted = True
        if not exhausted:
            yield (el1, el2)
        else:
            return

The solution can be tested with the following code:

it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it1, it2))           # len(it1) < len(it2) => raise
it1 = (x for x in ['a', 'b', 'c'])  # it1 has length 3
it2 = (x for x in [0, 1, 2, 3])     # it2 has length 4
list(zip_equal(it2, it1))           # len(it2) > len(it1) => raise
it1 = (x for x in ['a', 'b', 'c', 'd'])  # it1 has length 4
it2 = (x for x in [0, 1, 2, 3])          # it2 has length 4
list(zip_equal(it1, it2))                # like zip (or izip in python2)

Am I overlooking any alternative solution? Is there a simpler implementation of my zip_equal function?

Update:

  • If you can require Python 3.10 or newer, see Asocia’s answer
  • Thorough performance benchmarking and best-performing solution on Python < 3.10: Stefan’s answer
  • Simple answer without external dependencies:
    Martijn Pieters’ answer (please check the comments for a bugfix in some corner cases)
  • More complex than Martijn’s, but with better performance: cjerdonek’s answer
  • If you don’t mind a package dependency, see pylang’s answer
Asked By: zeehio


Answers:

I can think of a simpler solution: use itertools.zip_longest() and raise an exception if the sentinel value used to pad out shorter iterables is present in the tuple produced:

from itertools import zip_longest

def zip_equal(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo

Unfortunately, we can’t use zip() with yield from to avoid a Python-level loop with a test on each iteration; by the time the shortest iterator runs out, zip() has already advanced all preceding iterators and thus swallowed the evidence if there is just one extra item in those.
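
A quick way to see the problem with plain zip(): the extra item is lost only when the longer iterator is passed before the one that runs out. A small sketch (the variable names are illustrative):

```python
# Case 1: the longer iterator is passed first.
long_it = iter([1, 2, 3])
short_it = iter(['x', 'y'])
pairs_long_first = list(zip(long_it, short_it))   # [(1, 'x'), (2, 'y')]
# zip() pulled 3 from long_it before noticing that short_it was empty,
# then dropped it: nothing is left to reveal the mismatch.
leftover_long_first = list(long_it)               # []

# Case 2: the shorter iterator is passed first.
long_it = iter([1, 2, 3])
short_it = iter(['x', 'y'])
pairs_short_first = list(zip(short_it, long_it))  # [('x', 1), ('y', 2)]
# zip() stopped on short_it before touching long_it again, so 3 survives.
leftover_short_first = list(long_it)              # [3]

print(leftover_long_first, leftover_short_first)  # [] [3]
```

This is why the solutions in this thread either inspect every tuple (zip_longest) or attach something to the end of the iterables that records or signals exhaustion.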

Answered By: Martijn Pieters

Here is an approach that doesn’t require doing any extra checks with each loop of the iteration. This could be desirable especially for long iterables.

The idea is to pad each iterable with a “value” at the end that raises an exception when reached, and then do the needed verification only at the very end. The approach uses zip() and itertools.chain().

The code below was written for Python 3.5.

import itertools

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None

Below is what it looks like being used.

>>> list(zip_equal([1, 2], [3, 4], [5, 6]))
[(1, 3, 5), (2, 4, 6)]

>>> list(zip_equal([1, 2], [3], [4]))
RuntimeError: iterable 1 exhausted first

>>> list(zip_equal([1], [2, 3], [4]))
RuntimeError: iterable 1 is longer

>>> list(zip_equal([1], [2], [3, 4]))
RuntimeError: iterable 2 is longer
Answered By: cjerdonek

For reference, I came up with a solution using a sentinel iterable:

import itertools


class _SentinelException(Exception):
    def __iter__(self):
        raise _SentinelException


def zip_equal(iterable1, iterable2):
    # Once iterable1 is exhausted, chain() calls iter() on the sentinel,
    # which raises _SentinelException instead of StopIteration.
    i1 = itertools.chain(iterable1, _SentinelException())
    i2 = iter(iterable2)
    try:
        while True:
            yield (next(i1), next(i2))
    except _SentinelException:  # i1 reached its end
        try:
            next(i2)  # check whether i2 has reached its end too
        except StopIteration:
            pass
        else:
            raise ValueError('the second iterable is longer than the first one')
    except StopIteration:  # i2 reached its end; next(i1) already succeeded, so i1 is longer than i2
        raise ValueError('the first iterable is longer than the second one')
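
This trick depends on itertools.chain() calling iter() on each of its arguments lazily, only once the previous argument is exhausted; if iter() were called eagerly, _SentinelException.__iter__ would raise before the first pair was produced. A small sketch that checks this ordering, using a hypothetical Tail class that merely records when iter() is called on it:

```python
from itertools import chain

calls = []


class Tail:
    """Hypothetical iterable that logs when iter() is called on it."""

    def __iter__(self):
        calls.append('iter(Tail) called')
        return iter(())  # behaves as an empty iterable


it = chain([1, 2], Tail())
first_two = [next(it), next(it)]
calls_before = list(calls)  # []: chain() has not touched Tail yet
rest = list(it)             # the list is exhausted, so chain() now calls iter() on Tail
calls_after = list(calls)   # ['iter(Tail) called']
```
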
Answered By: XU Weijiang

Use more_itertools.zip_equal (v8.3.0+):

Code

import more_itertools as mit

Demo

list(mit.zip_equal(range(3), "abc"))
# [(0, 'a'), (1, 'b'), (2, 'c')]

list(mit.zip_equal(range(3), "abcd"))
# UnequalIterablesError

more_itertools is a third-party package, installed via pip install more_itertools

Answered By: pylang

PEP 618 introduced an optional boolean keyword argument, strict, for the built-in zip function.

Quoting What’s New In Python 3.10:

The zip() function now has an optional strict flag, used to require that all the iterables have an equal length.

When enabled, a ValueError is raised if one of the arguments is exhausted before the others.

>>> list(zip('ab', range(3)))
[('a', 0), ('b', 1)]
>>> list(zip('ab', range(3), strict=True))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: zip() argument 2 is longer than argument 1
Answered By: Asocia

A new solution, much faster than cjerdonek’s (on which it’s based), plus a benchmark. Benchmark first; my solution is green. Note that the "total size" is the same in all cases: two million values. The x-axis is the number of iterables, from 1 iterable with two million values, then 2 iterables with a million values each, all the way up to 100,000 iterables with 20 values each.

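[Figure: benchmark plot, "zip_equal(m iterables with 2,000,000/m values each)"; x-axis: number of zipped iterables; y-axis: time for complete iteration in milliseconds]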

The black one is Python’s zip. I used Python 3.8 here, so it doesn’t do this question’s task of checking for equal lengths, but I include it as a reference for the maximum speed one can hope for. You can see my solution is pretty close.

For the perhaps most common case of zipping two iterables, mine is almost three times as fast as the previously fastest solution by cjerdonek, and not much slower than zip. Times as text (in milliseconds):

         number of iterables     1     2     3     4     5    10   100  1000 10000 50000 100000
-----------------------------------------------------------------------------------------------
       more_itertools__pylang 209.3 132.1 105.8  93.7  87.4  74.4  54.3  51.9  53.9  66.9  84.5
   fillvalue__Martijn_Pieters 159.1 101.5  85.6  74.0  68.8  59.0  44.1  43.0  44.9  56.9  72.0
     chain_raising__cjerdonek  58.5  35.1  26.3  21.9  19.7  16.6  10.4  12.7  34.4 115.2 223.2
     ziptail__Stefan_Pochmann  10.3  12.4  10.4   9.2   8.7   7.8   6.7   6.8   9.4  22.6  37.8
                          zip  10.3   8.5   7.8   7.4   7.4   7.1   6.4   6.8   9.0  19.4  32.3

My code:

from itertools import chain


def zip_equal(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError('zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError('zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

The basic idea is to let zip(*iterables) do all the work and then, after it has stopped because some iterable was exhausted, check whether all iterables were equally long. They were if and only if:

  1. zip stopped because the first iterable didn’t have any more elements (i.e., no other iterable is shorter).
  2. None of the other iterables have any further elements (i.e., no other iterable is longer).

How I check these criteria:

  • Since I need to check these criteria after zip has ended, I can’t simply return the zip object. Instead, I chain an empty zip_tail iterator behind it that does the checking.
  • To support checking the first criterion, I chain an empty first_tail iterator behind the first iterable. Its sole job is to record that the first iterable’s iteration stopped (i.e., it was asked for another element and didn’t have one, so the first_tail iterator was asked instead).
  • To support checking the second criterion, I fetch iterators for all the other iterables and keep them in a list before I give them to zip.
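
The first_tail trick can be tried in isolation. The sketch below (the names tail and stopped are illustrative, not from the answer) chains an empty generator behind the first iterable and records when it is reached; after zip() stops, that record tells whether the first iterable is the one that ran out.

```python
from itertools import chain

stopped = []


def tail(label):
    # Runs only when chain() reaches it, i.e. after the preceding
    # iterable has been fully consumed.
    stopped.append(label)
    return
    yield  # unreachable; its presence makes tail() a generator function


# Equal lengths: zip() stops because `first` runs out, so the tail is reached.
first = chain([1, 2], tail('first'))
pairs_equal = list(zip(first, iter('ab')))   # [(1, 'a'), (2, 'b')]
stopped_equal = list(stopped)                # ['first']

# First iterable longer: zip() gets 3 from `first`, then stops on the
# second iterable, so the tail is never reached.
stopped.clear()
first = chain([1, 2, 3], tail('first'))
pairs_longer = list(zip(first, iter('ab')))  # [(1, 'a'), (2, 'b')]
stopped_longer = list(stopped)               # []
```

A shorter first iterable also reaches its tail, which is why the second criterion (no leftover elements in the other iterators) has to be checked separately.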

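Side note: more-itertools uses pretty much the same method as Martijn’s, but does proper is checks instead of Martijn’s not-quite-correct sentinel in combo (in also compares elements with ==, so an element with an unusual __eq__ can trigger a false positive or an error). Those per-element is checks are probably the main reason it’s slower.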
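
To make the difference concrete: in first checks identity and then falls back to ==, so an element whose __eq__ answers True for any argument looks like the sentinel. A minimal sketch with a hypothetical AlwaysEqual class (not taken from any library):

```python
class AlwaysEqual:
    """Hypothetical element type whose __eq__ claims equality with anything."""

    def __eq__(self, other):
        return True


sentinel = object()
combo = (AlwaysEqual(), 1)  # a fully populated tuple: no padding at all

found_with_in = sentinel in combo                  # True: `in` falls back to ==
found_with_is = any(x is sentinel for x in combo)  # False: identity is exact
print(found_with_in, found_with_is)                # True False
```

With the is-based test, a position is reported as padding only when zip_longest actually inserted the sentinel there.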

Benchmark code:

import timeit
import itertools
import warnings
from itertools import repeat, chain, zip_longest
from collections import deque
from sys import hexversion

#-----------------------------------------------------------------------------
# Solution by Martijn Pieters
#-----------------------------------------------------------------------------

def zip_equal__fillvalue__Martijn_Pieters(*iterables):
    sentinel = object()
    for combo in zip_longest(*iterables, fillvalue=sentinel):
        if sentinel in combo:
            raise ValueError('Iterables have different lengths')
        yield combo

#-----------------------------------------------------------------------------
# Solution by pylang
#-----------------------------------------------------------------------------

def zip_equal__more_itertools__pylang(*iterables):
    return more_itertools__zip_equal(*iterables)

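class UnequalIterablesError(ValueError):
    """Stand-in for more_itertools' exception, so this copy runs standalone."""
    def __init__(self, details=None):
        msg = 'Iterables have different lengths'
        if details is not None:
            msg += ': index 0 has length {}; index {} has length {}'.format(*details)
        super().__init__(msg)

_marker = object()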

def _zip_equal_generator(iterables):
    for combo in zip_longest(*iterables, fillvalue=_marker):
        for val in combo:
            if val is _marker:
                raise UnequalIterablesError()
        yield combo

def more_itertools__zip_equal(*iterables):
    """``zip`` the input *iterables* together, but raise
    ``UnequalIterablesError`` if they aren't all the same length.

        >>> it_1 = range(3)
        >>> it_2 = iter('abc')
        >>> list(zip_equal(it_1, it_2))
        [(0, 'a'), (1, 'b'), (2, 'c')]

        >>> it_1 = range(3)
        >>> it_2 = iter('abcd')
        >>> list(zip_equal(it_1, it_2)) # doctest: +IGNORE_EXCEPTION_DETAIL
        Traceback (most recent call last):
        ...
        more_itertools.more.UnequalIterablesError: Iterables have different
        lengths

    """
    if hexversion >= 0x30A00A6:
        warnings.warn(
            (
                'zip_equal will be removed in a future version of '
                'more-itertools. Use the builtin zip function with '
                'strict=True instead.'
            ),
            DeprecationWarning,
        )
    # Check whether the iterables are all the same size.
    try:
        first_size = len(iterables[0])
        for i, it in enumerate(iterables[1:], 1):
            size = len(it)
            if size != first_size:
                break
        else:
            # If we didn't break out, we can use the built-in zip.
            return zip(*iterables)

        # If we did break out, there was a mismatch.
        raise UnequalIterablesError(details=(first_size, i, size))
    # If any one of the iterables didn't have a length, start reading
    # them until one runs out.
    except TypeError:
        return _zip_equal_generator(iterables)

#-----------------------------------------------------------------------------
# Solution by cjerdonek
#-----------------------------------------------------------------------------

class ExhaustedError(Exception):
    def __init__(self, index):
        """The index is the 0-based index of the exhausted iterable."""
        self.index = index

def raising_iter(i):
    """Return an iterator that raises an ExhaustedError."""
    raise ExhaustedError(i)
    yield

def terminate_iter(i, iterable):
    """Return an iterator that raises an ExhaustedError at the end."""
    return itertools.chain(iterable, raising_iter(i))

def zip_equal__chain_raising__cjerdonek(*iterables):
    iterators = [terminate_iter(*args) for args in enumerate(iterables)]
    try:
        yield from zip(*iterators)
    except ExhaustedError as exc:
        index = exc.index
        if index > 0:
            raise RuntimeError('iterable {} exhausted first'.format(index)) from None
        # Check that all other iterators are also exhausted.
        for i, iterator in enumerate(iterators[1:], start=1):
            try:
                next(iterator)
            except ExhaustedError:
                pass
            else:
                raise RuntimeError('iterable {} is longer'.format(i)) from None
            
#-----------------------------------------------------------------------------
# Solution by Stefan Pochmann
#-----------------------------------------------------------------------------

def zip_equal__ziptail__Stefan_Pochmann(*iterables):

    # For trivial cases, use pure zip.
    if len(iterables) < 2:
        return zip(*iterables)

    # Tail for the first iterable
    first_stopped = False
    def first_tail():
        nonlocal first_stopped 
        first_stopped = True
        return
        yield

    # Tail for the zip
    def zip_tail():
        if not first_stopped:
            raise ValueError(f'zip_equal: first iterable is longer')
        for _ in chain.from_iterable(rest):
            raise ValueError(f'zip_equal: first iterable is shorter')
            yield

    # Put the pieces together
    iterables = iter(iterables)
    first = chain(next(iterables), first_tail())
    rest = list(map(iter, iterables))
    return chain(zip(first, *rest), zip_tail())

#-----------------------------------------------------------------------------
# List of solutions to be speedtested
#-----------------------------------------------------------------------------

solutions = [
    zip_equal__more_itertools__pylang,
    zip_equal__fillvalue__Martijn_Pieters,
    zip_equal__chain_raising__cjerdonek,
    zip_equal__ziptail__Stefan_Pochmann,
    zip,
]

def name(solution):
    return solution.__name__[11:] or 'zip'

#-----------------------------------------------------------------------------
# The speedtest code
#-----------------------------------------------------------------------------

def test(m, n):
    """Speedtest all solutions with m iterables of n elements each."""

    all_times = {solution: [] for solution in solutions}
    def show_title():
        print(f'{m} iterators of length {n:,}:')
    if verbose: show_title()
    def show_times(times, solution):
        print(*('%3d ms ' % t for t in times),
              name(solution))
        
    for _ in range(3):
        for solution in solutions:
            times = sorted(timeit.repeat(lambda: deque(solution(*(repeat(i, n) for i in range(m))), 0), number=1, repeat=5))[:3]
            times = [round(t * 1e3, 3) for t in times]
            all_times[solution].append(times)
            if verbose: show_times(times, solution)
        if verbose: print()
        
    if verbose:
        print('best by min:')
        show_title()
        for solution in solutions:
            show_times(min(all_times[solution], key=min), solution)
        print('best by max:')
    show_title()
    for solution in solutions:
        show_times(min(all_times[solution], key=max), solution)
    print()

    stats.append((m,
                  [min(all_times[solution], key=min)
                   for solution in solutions]))

#-----------------------------------------------------------------------------
# Run the speedtest for several numbers of iterables
#-----------------------------------------------------------------------------

stats = []
verbose = False
total_elements = 2 * 10**6
for m in 1, 2, 3, 4, 5, 10, 100, 1000, 10000, 50000, 100000:
    test(m, total_elements // m)

#-----------------------------------------------------------------------------
# Print the speedtest results for use in the plotting script
#-----------------------------------------------------------------------------

print('data for plotting by https://replit.com/@pochmann/zipequal-plot')
names = [name(solution) for solution in solutions]
print(f'{names = }')
print(f'{stats = }')

Code for plotting/table (also at Replit):

import matplotlib.pyplot as plt

names = ['more_itertools__pylang', 'fillvalue__Martijn_Pieters', 'chain_raising__cjerdonek', 'ziptail__Stefan_Pochmann', 'zip']
stats = [(1, [[208.762, 211.211, 214.189], [159.568, 162.233, 162.24], [57.668, 58.94, 59.23], [10.418, 10.583, 10.723], [10.057, 10.443, 10.456]]), (2, [[130.065, 130.26, 130.52], [100.314, 101.206, 101.276], [34.405, 34.998, 35.188], [12.152, 12.473, 12.773], [8.671, 8.857, 9.395]]), (3, [[106.417, 107.452, 107.668], [90.693, 91.154, 91.386], [26.908, 27.863, 28.145], [10.457, 10.461, 10.789], [8.071, 8.157, 8.228]]), (4, [[97.547, 98.686, 98.726], [77.076, 78.31, 79.381], [23.134, 23.176, 23.181], [9.321, 9.4, 9.581], [7.541, 7.554, 7.635]]), (5, [[86.393, 88.046, 88.222], [68.633, 69.649, 69.742], [19.845, 20.006, 20.135], [8.726, 8.935, 9.016], [7.201, 7.26, 7.304]]), (10, [[70.384, 71.762, 72.473], [57.87, 58.149, 58.411], [15.808, 16.252, 16.262], [7.568, 7.57, 7.864], [6.732, 6.888, 6.911]]), (100, [[53.108, 54.245, 54.465], [44.436, 44.601, 45.226], [10.502, 11.073, 11.109], [6.721, 6.733, 6.847], [6.753, 6.774, 6.815]]), (1000, [[52.119, 52.476, 53.341], [42.775, 42.808, 43.649], [12.538, 12.853, 12.862], [6.802, 6.971, 7.002], [6.679, 6.724, 6.838]]), (10000, [[54.802, 55.006, 55.187], [45.981, 46.066, 46.735], [34.416, 34.672, 35.009], [9.485, 9.509, 9.626], [9.036, 9.042, 9.112]]), (50000, [[66.681, 66.98, 67.441], [56.593, 57.341, 57.631], [113.988, 114.022, 114.106], [22.088, 22.412, 22.595], [19.412, 19.431, 19.934]]), (100000, [[86.846, 88.111, 88.258], [74.796, 75.431, 75.927], [218.977, 220.182, 223.343], [38.89, 39.385, 39.88], [32.332, 33.117, 33.594]])]

colors = {
    'more_itertools__pylang': 'm',
    'fillvalue__Martijn_Pieters': 'red',
    'chain_raising__cjerdonek': 'gold',
    'ziptail__Stefan_Pochmann': 'lime',
    'zip': 'black',
}

ns = [n for n, _ in stats]
print('%28s' % 'number of iterables', *('%5d' % n for n in ns))
print('-' * 95)
x = range(len(ns))
for i, name in enumerate(names):
    ts = [min(tss[i]) for _, tss in stats]
    color = colors[name]
    if color:
        plt.plot(x, ts, '.-', color=color, label=name)
        print('%29s' % name, *('%5.1f' % t for t in ts))
plt.xticks(x, ns, size=9)
plt.ylim(0, 133)
plt.title('zip_equal(m iterables with 2,000,000/m values each)', weight='bold')
plt.xlabel('Number of zipped *iterables* (not their lengths)', weight='bold')
plt.ylabel('Time (for complete iteration) in milliseconds', weight='bold')
plt.legend(loc='upper center')
#plt.show()
plt.savefig('zip_equal_plot.png', dpi=200)
Answered By: Stefan Pochmann