split a generator/iterable every n items in python (splitEvery)

Question:

I’m trying to write the Haskell function ‘splitEvery’ in Python. Here is its definition:

splitEvery :: Int -> [e] -> [[e]]
    @'splitEvery' n@ splits a list into length-n pieces.  The last
    piece will be shorter if @n@ does not evenly divide the length of
    the list.

The basic version of this works fine, but I want a version that works with generator expressions, lists, and iterators. And if there is a generator as input, it should return a generator as output!

Tests

# should not enter an infinite loop with generators or lists
splitEvery(10, itertools.count())
splitEvery(10, range(1000))

# last piece must be shorter if n does not evenly divide
assert splitEvery(5, range(9)) == [[0, 1, 2, 3, 4], [5, 6, 7, 8]]

# should give same correct results with generators
tmp = itertools.islice(itertools.count(), 10)
assert list(splitEvery(5, tmp)) == [[0, 1, 2, 3, 4], [5, 6, 7, 8]]

Current Implementation

Here is the code I currently have, but it doesn’t work with a simple list: islice builds a fresh iterator over the list on every call, so it yields the first n items forever.

def splitEvery_1(n, iterable):
    res = list(itertools.islice(iterable, n))
    while len(res) != 0:
        yield res
        res = list(itertools.islice(iterable, n))

This one doesn’t work with a generator expression, since generators support neither len() nor slicing (thanks to jellybean for fixing it):

def splitEvery_2(n, iterable): 
    return [iterable[i:i+n] for i in range(0, len(iterable), n)]

There has to be a simple piece of code that does the splitting. I know I could just have different functions, but it seems like it should be an easy thing to do. I’m probably getting stuck on an unimportant problem, but it’s really bugging me.


It is similar to grouper from http://docs.python.org/library/itertools.html#itertools.groupby but I don’t want it to fill extra values.

def grouper(n, iterable, fillvalue=None):
    "grouper(3, 'ABCDEFG', 'x') --> ABC DEF Gxx"
    args = [iter(iterable)] * n
    return izip_longest(fillvalue=fillvalue, *args)
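
For example, it pads the last group out to length n with the fillvalue (Python 2 shown, since izip_longest is the Python 2 name):

>>> list(grouper(5, range(9)))
[(0, 1, 2, 3, 4), (5, 6, 7, 8, None)]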

It does mention a method that truncates the last value. This isn’t what I want either.

The left-to-right evaluation order of the iterables is guaranteed. This makes possible an idiom for clustering a data series into n-length groups using izip(*[iter(s)]*n).

list(izip(*[iter(range(9))]*5)) == [(0, 1, 2, 3, 4)]
# should be [(0, 1, 2, 3, 4), (5, 6, 7, 8)]
Asked By: James Brooks

Answers:

Why not do it like this? Looks almost like your splitEvery_2 function.

def splitEveryN(n, it):
    return [it[i:i+n] for i in range(0, len(it), n)]

Actually, it only takes away the unnecessary step interval from the slice in your solution. 🙂

Answered By: Johannes Charra

Here is how you deal with a list vs. an iterator:

def is_list(L):
    # Implement it somehow - returns True or False
    ...

# Return list(result) for list input, or result unchanged for iterators:
return (lambda x: x, list)[int(is_list(L))](result)

Answered By: Hamish Grubijan

from itertools import islice

def split_every(n, iterable):
    i = iter(iterable)
    piece = list(islice(i, n))
    while piece:
        yield piece
        piece = list(islice(i, n))

Some tests:

>>> list(split_every(5, range(9)))
[[0, 1, 2, 3, 4], [5, 6, 7, 8]]

>>> list(split_every(3, (x**2 for x in range(20))))
[[0, 1, 4], [9, 16, 25], [36, 49, 64], [81, 100, 121], [144, 169, 196], [225, 256, 289], [324, 361]]

>>> [''.join(s) for s in split_every(6, 'Hello world')]
['Hello ', 'world']

>>> list(split_every(100, []))
[]
Answered By: Roberto Bonvallet

I think this question is almost the same as ones that have been asked before.

Changing it a little to crop the last chunk instead of padding it, I think a good solution for the generator case would be:

from itertools import islice

def iter_grouper(n, iterable):
    it = iter(iterable)
    # islice returns a (truthy) iterator even when empty,
    # so materialize each chunk to test for emptiness
    item = list(islice(it, n))
    while item:
        yield item
        item = list(islice(it, n))

For objects that support slicing (lists, strings, tuples), we can do:

def slice_grouper(n, sequence):
    return [sequence[i:i+n] for i in range(0, len(sequence), n)]

Now it’s just a matter of dispatching to the correct method:

def grouper(n, iter_or_seq):
    # __getslice__ exists only in Python 2; on Python 3 you would
    # check for __getitem__ or collections.abc.Sequence instead
    if hasattr(iter_or_seq, "__getslice__"):
        return slice_grouper(n, iter_or_seq)
    elif hasattr(iter_or_seq, "__iter__"):
        return iter_grouper(n, iter_or_seq)

I think you could polish it a little bit more 🙂
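
A quick check (on Python 2, where range returns a list and lists have __getslice__):

>>> grouper(3, range(7))
[[0, 1, 2], [3, 4, 5], [6]]
>>> list(grouper(3, iter(range(7))))
[[0, 1, 2], [3, 4, 5], [6]]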

Answered By: fortran
def chunks(iterable, n):
    """assumes n is an integer > 0
    """
    iterable = iter(iterable)
    while True:
        result = []
        for i in range(n):
            try:
                a = next(iterable)
            except StopIteration:
                break
            else:
                result.append(a)
        if result:
            yield result
        else:
            break

g1 = (i*i for i in range(10))
g2 = chunks(g1, 3)
print g2
# <generator object chunks at 0x0337B9B8>
print list(g2)
# [[0, 1, 4], [9, 16, 25], [36, 49, 64], [81]]
Answered By: Rusty Rob

This will do the trick for pairs (n = 2):

from itertools import izip_longest
izip_longest(it[::2], it[1::2])

where it is a sequence that supports slicing.


Example:

izip_longest('abcdef'[::2], 'abcdef'[1::2]) -> ('a', 'b'), ('c', 'd'), ('e', 'f')

Let’s break this down

'abcdef'[::2] -> 'ace'
'abcdef'[1::2] -> 'bdf'

As you can see, the last number in the slice specifies the step that will be used to pick up items. You can read more about extended slices in the Python documentation.

The zip function takes the first item from the first iterable and combines it with the first item from the second iterable, then does the same for the second and third items, and so on until one of the iterables runs out of values (izip_longest instead pads with the fillvalue until the longest iterable is exhausted).

The result is an iterator. If you want a list use the list() function on the result.
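
The same trick generalizes to any chunk size for sliceable inputs. A rough sketch (split_every_slices is just an illustrative name, and note that izip_longest pads the ragged last chunk with None, which the asker wanted to avoid):

from itertools import izip_longest

def split_every_slices(n, seq):
    # Interleave n strided slices: seq[0::n], seq[1::n], ... zip back into chunks.
    return list(izip_longest(*(seq[i::n] for i in range(n))))

>>> split_every_slices(3, 'abcdefgh')
[('a', 'b', 'c'), ('d', 'e', 'f'), ('g', 'h', None)]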

Here’s a quick one-liner version. Like Haskell’s, it is lazy.

from itertools import islice, takewhile, repeat
split_every = (lambda n, it:
    takewhile(bool, (list(islice(it, n)) for _ in repeat(None))))

This requires that you call iter on the input before passing it to split_every.

Example:

list(split_every(5, iter(xrange(9))))
[[0, 1, 2, 3, 4], [5, 6, 7, 8]]

Although not a one-liner, the version below doesn’t require that you call iter, which can be a common pitfall.

from itertools import islice, takewhile, repeat

def split_every(n, iterable):
    """
    Slice an iterable into chunks of n elements
    :type n: int
    :type iterable: Iterable
    :rtype: Iterator
    """
    iterator = iter(iterable)
    return takewhile(bool, (list(islice(iterator, n)) for _ in repeat(None)))

(Thanks to @eli-korvigo for improvements.)

Answered By: Elliot Cameron

This is an answer that works for both list and generator:

from itertools import count, groupby
def split_every(size, iterable):
    c = count()
    for k, g in groupby(iterable, lambda x: next(c)//size):
        yield list(g) # or yield g if you want to output a generator
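
For example:

>>> list(split_every(5, range(9)))
[[0, 1, 2, 3, 4], [5, 6, 7, 8]]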
Answered By: justhalf

A one-liner, inlineable solution (supports Python 2/3 and iterators, uses only the standard library and a single generator comprehension):

import itertools
def split_groups(iter_in, group_size):
     return ((x for _, x in item) for _, item in itertools.groupby(enumerate(iter_in), key=lambda x: x[0] // group_size))
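
For example, materializing the nested generators (they must be consumed in order, since groupby invalidates a group once you advance past it):

>>> [list(g) for g in split_groups(range(9), 5)]
[[0, 1, 2, 3, 4], [5, 6, 7, 8]]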
Answered By: Andrey Cizov

If you want a solution that

  • uses generators only (no intermediate lists or tuples),
  • works for very long (or infinite) iterators,
  • works for very large batch sizes,

this does the trick (the code below is Python 2; a Python 3 sketch follows the explanation):

def one_batch(first_value, iterator, batch_size):
    yield first_value
    for i in xrange(1, batch_size):
        yield iterator.next()

def batch_iterator(iterator, batch_size):
    iterator = iter(iterator)
    while True:
        first_value = iterator.next()  # Peek.
        yield one_batch(first_value, iterator, batch_size)

It works by peeking at the next value in the iterator and passing that as the first value to a generator (one_batch()) that will yield it, along with the rest of the batch.

The peek step raises StopIteration exactly when the input iterator is exhausted and there are no more batches, which is the right moment for batch_iterator() to stop, so in Python 2 there is no need to catch the exception. Note, however, that since PEP 479 (enforced from Python 3.7), a StopIteration escaping a generator is turned into a RuntimeError, so on modern Python the exception must be caught and converted to a return.
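
A minimal Python 3 sketch of the same peek-based approach, with the StopIteration caught explicitly:

def one_batch(first_value, iterator, batch_size):
    # Yield the peeked value, then up to batch_size - 1 more items.
    yield first_value
    for _ in range(1, batch_size):
        try:
            yield next(iterator)
        except StopIteration:
            return

def batch_iterator(iterator, batch_size):
    iterator = iter(iterator)
    while True:
        try:
            first_value = next(iterator)  # Peek; end cleanly when exhausted.
        except StopIteration:
            return
        yield one_batch(first_value, iterator, batch_size)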

This will process lines from stdin in batches:

import sys

for input_batch in batch_iterator(sys.stdin, 10000):
    for line in input_batch:
        process(line)
    finalise()

I’ve found this useful for processing lots of data and uploading the results in batches to an external store.

Answered By: Carl

Building off of the accepted answer and employing a lesser-known use of iter (when passed a second argument, it calls the first until it returns the second), you can do this really easily:

python3:

from itertools import islice

def split_every(n, iterable):
    iterable = iter(iterable)
    yield from iter(lambda: list(islice(iterable, n)), [])

python2:

from itertools import islice

def split_every(n, iterable):
    iterable = iter(iterable)
    for chunk in iter(lambda: list(islice(iterable, n)), []):
        yield chunk
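
For example:

>>> list(split_every(4, range(10)))
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]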
Answered By: acushner

more_itertools has a chunked function:

import more_itertools as mit

list(mit.chunked(range(9), 5))
# [[0, 1, 2, 3, 4], [5, 6, 7, 8]]
Answered By: pylang

I came across this while trying to chop up batches too, but on a generator from a stream, so most of the solutions here either aren’t applicable or don’t work in Python 3.

For people still stumbling upon this, here’s a general solution using itertools:

from itertools import islice, chain

def iter_in_slices(iterator, size=None):
    while True:
        slice_iter = islice(iterator, size)
        # Peek at the first object; if there is none, the input is exhausted.
        # (Since PEP 479 the StopIteration must be caught, not allowed to escape.)
        try:
            peek = next(slice_iter)
        except StopIteration:
            return
        # Put the first object back and yield the slice
        yield chain([peek], slice_iter)
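
Each yielded slice shares the underlying iterator, so consume each one before advancing to the next:

>>> [list(s) for s in iter_in_slices(iter(range(9)), 5)]
[[0, 1, 2, 3, 4], [5, 6, 7, 8]]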
Answered By: Ashley Waite

A fully lazy solution for input/output of generators, including some checking.

def chunks(items, binsize):
    consumed = [0]
    sent = [0]
    it = iter(items)

    def g():
        c = 0
        while c < binsize:
            try:
                val = next(it)
            except StopIteration:
                sent[0] = None
                return
            consumed[0] += 1
            yield val
            c += 1

    while consumed[0] <= sent[0]:
        if consumed[0] < sent[0]:
            raise Exception("Cannot traverse a chunk before the previous is consumed.", consumed[0], sent[0])
        yield g()
        if sent[0] is None:
            return
        sent[0] += binsize


from time import sleep

def g():
    for item in [1, 2, 3, 4, 5, 6, 7]:
        sleep(1)
        print(f"accessed:{item}→\t", end="")
        yield item


for chunk in chunks(g(), 3):
    for x in chunk:
        print(f"x:{x}\t\t\t", end="")
    print()

"""
Output:

accessed:1→ x:1         accessed:2→ x:2         accessed:3→ x:3         
accessed:4→ x:4         accessed:5→ x:5         accessed:6→ x:6         
accessed:7→ x:7 
"""
Answered By: dawid