What is the most efficient way of getting the intersection of k sorted arrays?

Question:

Given k sorted arrays, what is the most efficient way of getting the intersection of these lists?

Example

INPUT:

[[1,3,5,7], [1,1,3,5,7], [1,4,7,9]] 

Output:

[1,7]

Based on what I read in the Elements of Programming Interviews book, there is a way to get the union of k sorted arrays in O(n log k) time. I was wondering if there is a way to do something similar for the intersection as well.

# merge sorted arrays in n log k time (appending and re-sorting is n log n time)
import heapq

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each sorted array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:  # a plain truth test would drop the value 0
            heapq.heappush(heap, (elem, idx))

    res = []

    # collect results in n log k time
    while heap:
        elem, ary = heapq.heappop(heap)
        it = srtd_iters[ary]
        res.append(elem)
        nxt = next(it, None)
        if nxt is not None:
            heapq.heappush(heap, (nxt, ary))

    return res

EDIT: obviously this is an algorithm question that I am trying to solve so I cannot use any of the inbuilt functions like set intersection etc

Asked By: identical123456


Answers:

You can use reduce:

from functools import reduce

a = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
reduce(lambda x, y: x & set(y), a[1:], set(a[0]))
 {1, 7}
Answered By: onyambu

You can use built-in sets and set intersection:

d = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]] 
result = set(d[0]).intersection(*d[1:])
{1, 7}
Answered By: AlexTorx

Exploiting sort order

Here is a single pass O(n) approach that doesn’t require any special data structures or auxiliary memory beyond the fundamental requirement of one iterator per input.

from itertools import cycle, islice

def intersection(inputs):
    "Yield the intersection of elements from multiple sorted inputs."
    # intersection(['ABBCD', 'BBDE', 'BBBDDE']) --> B B D
    n = len(inputs)
    iters = cycle(map(iter, inputs))
    try:
        candidate = next(next(iters))
        while True:
            for it in islice(iters, n-1):
                while (value := next(it)) < candidate:
                    pass
                if value != candidate:
                    candidate = value
                    break
            else:
                yield candidate
                candidate = next(next(iters))
    except StopIteration:
        return

Here’s a sample session:

>>> data = [[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]
>>> list(intersection(data))
[1, 7]

>>> data = [[1,1,2,3], [1,1,4,4]]
>>> list(intersection(data))
[1, 1]

Algorithm in words

The algorithm starts by selecting the next value from the next iterator to be a candidate.

The main loop assumes a candidate has been selected and it loops over the next n - 1 iterators. For each of those iterators, it consumes values until it finds a value that is at least as large as the candidate. If that value is larger than the candidate, that value becomes the new candidate and the main loop starts again. If all n - 1 values are equal to the candidate, then the candidate is emitted and a new candidate is fetched.

When any input iterator is exhausted, the algorithm is complete.

Doing it without libraries (core language only)

The same algorithm works fine (though less beautifully) without using itertools. Just replace cycle and islice with their list based equivalents:

def intersection(inputs):
    "Yield the intersection of elements from multiple sorted inputs."
    # intersection(['ABBCD', 'BBDE', 'BBBDDE']) --> B B D
    n = len(inputs)
    iters = list(map(iter, inputs))
    curr_iter = 0
    try:
        it = iters[curr_iter]
        curr_iter = (curr_iter + 1) % n
        candidate = next(it)
        while True:
            for i in range(n - 1):
                it = iters[curr_iter]
                curr_iter = (curr_iter + 1) % n
                while (value := next(it)) < candidate:
                    pass
                if value != candidate:
                    candidate = value
                    break
            else:
                yield candidate
                it = iters[curr_iter]
                curr_iter = (curr_iter + 1) % n
                candidate = next(it)
    except StopIteration:
        return
Answered By: Raymond Hettinger

Yes, it is possible! I’ve modified your example code to do this.

My answer assumes that your question is about the algorithm – if you want the fastest-running code using sets, see other answers.

This maintains the O(n log(k)) time complexity: all the code between if lowest != elem or ary != times_seen: and unbench_all = False is O(log(k)). There is a nested loop inside the main loop (for unbenched in range(times_seen):) but this only runs times_seen times, and times_seen is initially 0 and is reset to 0 after every time this inner loop is run, and can only be incremented once per main loop iteration, so the inner loop cannot do more iterations in total than the main loop. Thus, since the code inside the inner loop is O(log(k)) and runs at most as many times as the outer loop, and the outer loop is O(log(k)) and runs n times, the algorithm is O(n log(k)).

This algorithm relies upon how tuples are compared in Python. It compares the first items of the tuples, and if they are equal, it compares the second items (i.e. (x, a) < (x, b) is true if and only if a < b).
In this algorithm, unlike in the example code in the question, when an item is popped from the heap, it is not necessarily pushed again in the same iteration. Since we need to check if all sub-lists contain the same number, after a number is popped from the heap, its sub-list is what I call "benched", meaning that it is not added back to the heap. This is because we need to check if other sub-lists contain the same item, so adding this sub-list’s next item is not needed right now.

If a number is indeed in all sub-lists, then the heap will look something like [(2,0),(2,1),(2,2),(2,3)], with all the first elements of the tuples the same, so heappop will select the one with the lowest sub-list index. This means that first index 0 will be popped and times_seen will be incremented to 1, then index 1 will be popped and times_seen will be incremented to 2 – if ary is not equal to times_seen then the number is not in the intersection of all sub-lists. This leads to the condition if lowest != elem or ary != times_seen:, which decides when a number shouldn’t be in the result. The else branch of this if statement is for when it still might be in the result.
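The tie-breaking behaviour described above can be checked directly. This is only a small illustration (the heap contents are made up to match the example in the paragraph above), not part of the answer's algorithm:

```python
import heapq

# Tuples compare element by element: the second item only matters on ties.
assert (2, 0) < (2, 1) < (3, 0)

# A heap of (value, sub-list index) pairs where every sub-list currently holds 2.
heap = [(2, 3), (2, 1), (2, 0), (2, 2)]
heapq.heapify(heap)

# Popping yields the sub-list indices in ascending order: 0, 1, 2, 3.
popped = [heapq.heappop(heap) for _ in range(4)]
print(popped)  # [(2, 0), (2, 1), (2, 2), (2, 3)]
```

This ordering is what lets the answer compare ary against times_seen instead of remembering which sub-lists have already been popped.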

The unbench_all boolean is for when all sub-lists need to be removed from the bench – this could be because:

  1. The current number is known to not be in the intersection of the sub-lists
  2. It is known to be in the intersection of the sub-lists

When unbench_all is True, all the sub-lists that were removed from the heap are re-added. It is known that these are the ones with indices in range(times_seen) since the algorithm removes items from the heap only if they have the same number, so they must have been removed in order of index, contiguously and starting from index 0, and there must be times_seen of them. This means that we don’t need to store the indices of the benched sub-lists, only how many of them have been benched.

import heapq


def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:
            heapq.heappush(heap, (elem, idx))

    res = []

    # the number of times that the current number has been seen
    times_seen = 0

    # the lowest number from the heap - currently checking if the first numbers in all sub-lists are equal to this
    lowest = heap[0][0] if heap else None

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        unbench_all = True

        if lowest != elem or ary != times_seen:
            if lowest == elem:
                heapq.heappop(heap)
                it = srtd_iters[ary]
                nxt = next(it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, ary))
        else:
            heapq.heappop(heap)
            times_seen += 1

            if times_seen == len(srtd_arys):
                res.append(elem)
            else:
                unbench_all = False

        if unbench_all:
            for unbenched in range(times_seen):
                unbenched_it = srtd_iters[unbenched]
                nxt = next(unbenched_it, None)
                if nxt is not None:
                    heapq.heappush(heap, (nxt, unbenched))
            times_seen = 0
            if heap:
                lowest = heap[0][0]

    return res


if __name__ == '__main__':
    a1 = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]
    a2 = [[1, 1], [1, 1, 2, 2, 3]]
    for arys in [a1, a2]:
        print(mergeArys(arys))

An equivalent algorithm can be written like this, if you prefer:

def mergeArys(srtd_arys):
    heap = []
    srtd_iters = [iter(x) for x in srtd_arys]

    # put the first element from each srtd array onto the heap
    for idx, it in enumerate(srtd_iters):
        elem = next(it, None)
        if elem is not None:
            heapq.heappush(heap, (elem, idx))

    res = []

    # collect results in nlogK time
    while heap:
        elem, ary = heap[0]
        lowest = elem
        keep_elem = True
        for i in range(len(srtd_arys)):
            elem, ary = heap[0]
            if lowest != elem or ary != i:
                if ary != i:
                    heapq.heappop(heap)
                    it = srtd_iters[ary]
                    nxt = next(it, None)
                    if nxt is not None:
                        heapq.heappush(heap, (nxt, ary))

                keep_elem = False
                i -= 1
                break
            heapq.heappop(heap)

        if keep_elem:
            res.append(elem)

        for unbenched in range(i+1):
            unbenched_it = srtd_iters[unbenched]
            nxt = next(unbenched_it, None)
            if nxt is not None:
                heapq.heappush(heap, (nxt, unbenched))

        if len(heap) < len(srtd_arys):
            heap = []

    return res

Answered By: Oli

I’ve come up with this algorithm. It doesn’t exceed O(nk); I don’t know if that’s good enough for you. The idea is to keep one index per array (k indexes in total). On each iteration you find the indexes of the next element in the intersection, then advance every index, stopping once an index runs past the end of its array, because then there can be no more items in the intersection. The trick is that, since the arrays are sorted, you can compare two elements from two different arrays and, if one is bigger than the other, immediately discard the smaller one, because everything that remains in the other array is at least as large as the element you are looking at. In the worst case every index is advanced to the end of its array, which takes kn time since an index never decreases.

def intersection(arrays):
    k = len(arrays)
    inter = []
    if k == 0:
        return inter
    indexes = [0] * k  # one index per array; indexes only ever increase

    while indexes[0] < len(arrays[0]):
        target = arrays[0][indexes[0]]
        matched = True
        for i in range(1, k):
            # skip everything in array i that is smaller than the target
            while indexes[i] < len(arrays[i]) and arrays[i][indexes[i]] < target:
                indexes[i] += 1
            if indexes[i] >= len(arrays[i]):
                return inter  # array i is exhausted, nothing more can match
            if arrays[i][indexes[i]] > target:
                matched = False  # target is missing from array i
                break
        if matched:
            inter.append(target)
            indexes = [idx + 1 for idx in indexes]
        else:
            indexes[0] += 1
    return inter
Answered By: egjlmn1

You can use bit-masking with one-hot encoding. The inner lists become maxterms. You AND them together for the intersection and OR them for the union. Then you have to convert back, for which I’ve used a bit hack.

problem = [[1,3,5,7],[1,1,3,5,8,7],[1,4,7,9]]

debruijn = [0, 1, 28, 2, 29, 14, 24, 3, 30, 22, 20, 15, 25, 17, 4, 8,
    31, 27, 13, 23, 21, 19, 16, 7, 26, 12, 18, 6, 11, 5, 10, 9]
u32 = accum = (1 << 32) - 1
for vec in problem:
    maxterm = 0
    for v in vec:
        maxterm |= 1 << v  # one-hot encode each value
    accum &= maxterm

# https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogDeBruijn
result = []
while accum:
    power = accum
    accum &= accum - 1  # Peter Wegner CACM 3 (1960), 322
    power &= ~accum     # power is now the lowest set bit that was just cleared
    result.append(debruijn[((power * 0x077CB531) & u32) >> 27])

print(result)

This uses (simulates) 32-bit integers, so you can only have [0, 31] in your sets.
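As a side note on the [0, 31] limit: Python integers are unbounded, so the same one-hot idea can be sketched without the 32-bit mask or the de Bruijn table. This is a variation under my own assumptions, not the answer's code: bitmask_intersection is a made-up name, values must be non-negative integers, and the masks grow with the largest value, so it only makes sense for small values.

```python
def bitmask_intersection(lists):
    """Intersect lists of small non-negative ints using Python's unbounded ints."""
    if not lists:
        return []
    accum = -1  # every bit set (two's-complement view), the identity for &
    for values in lists:
        mask = 0
        for v in values:
            mask |= 1 << v  # one-hot encode each value
        accum &= mask
    result = []
    while accum:
        lowest = accum & -accum                 # isolate the lowest set bit
        result.append(lowest.bit_length() - 1)  # its position is the value
        accum ^= lowest                         # clear that bit
    return result


print(bitmask_intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
```

As with the original, duplicates collapse, so this gives set semantics rather than keeping repeated values.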

*I am inexperienced at Python, so I timed it. One should definitely use set.intersection.

This is O(n), but in my timing set.intersection was 5.5 times faster.

Answered By: Neil

You said we can’t use sets but how about dicts / hash tables? (yes I know they’re basically the same thing) 😀

If so, here’s a fairly simple approach (please excuse the py2 syntax):

arrays = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
counts = {}

for ar in arrays:
  last = None
  for i in ar:
    if (i != last):
      counts[i] = counts.get(i, 0) + 1
    last = i

N = len(arrays)
intersection = [i for i, n in counts.iteritems() if n == N]
print intersection
Answered By: dimven

Same as Raymond Hettinger’s solution but with more basic python code:

def intersection(arrays, unique: bool=False):
    result = []
    if not len(arrays) or any(not len(array) for array in arrays):
        return result

    pointers = [0] * len(arrays)

    target = arrays[0][0]
    start_step = 0
    current_step = 1
    while True:
        idx = current_step % len(arrays)
        array = arrays[idx]

        while pointers[idx] < len(array) and array[pointers[idx]] < target:
            pointers[idx] += 1

        if pointers[idx] < len(array) and array[pointers[idx]] > target:
            target = array[pointers[idx]]
            start_step = current_step
            current_step += 1
            continue

        if unique:
            while (
                pointers[idx] + 1 < len(array)
                and array[pointers[idx]] == array[pointers[idx] + 1]
            ):
                pointers[idx] += 1

        if (current_step - start_step) == len(arrays):
            result.append(target)
            for other_idx, other_array in enumerate(arrays):
                pointers[other_idx] += 1
            if pointers[idx] < len(array):
                target = array[pointers[idx]]
                start_step = current_step

        if pointers[idx] == len(array):
            return result

        current_step += 1
Answered By: n1amr

Here’s an O(n) answer (where n = sum(len(sublist) for sublist in data)).

from itertools import cycle

def intersection(data):
    result = []    
    maxval = float("-inf")
    consecutive = 0
    try:
        for sublist in cycle(iter(sublist) for sublist in data):

            value = next(sublist)
            while value < maxval:
                value = next(sublist)

            if value > maxval:
                maxval = value
                consecutive = 0
                continue

            consecutive += 1
            if consecutive >= len(data)-1:
                result.append(maxval)
                consecutive = 0

    except StopIteration:
        return result

print(intersection([[1,3,5,7], [1,1,3,5,7], [1,4,7,9]]))

[1, 7]

Answered By: sradc

Some of the methods above do not cover the case where a value is duplicated in every sub-list. The code below implements that intersection, and it is more efficient when the sub-lists contain many duplicates 🙂 If you are not sure about duplicates, it is recommended to use Counter (from collections import Counter). The custom counter function is there to handle large numbers of duplicates more efficiently. It still cannot beat Raymond Hettinger’s implementation, though.

def counter(my_list):
    my_list = sorted(my_list)
    first_val, *all_val = my_list
    p_index = my_list.index(first_val)
    my_counter = {}
    for item in all_val:
         c_index = my_list.index(item)
         diff = abs(c_index-p_index)
         p_index = c_index
         my_counter[first_val] = diff 
         first_val = item
    c_index = my_list.index(first_val)  # first_val is the last distinct value here; also works for one-element lists
    diff = len(my_list) - c_index
    my_counter[first_val] = diff 
    return my_counter

def my_func(data):
    if not data or not isinstance(data, list):
        return
    # get the first value
    first_val, *all_val = data
    if not isinstance(first_val, list):
        return
    # count items in first value
    p = counter(first_val) # counter({1: 2, 3: 1, 5: 1, 7: 1})
    # collect all common items and calculate the minimum occurrence in the intersection
    for val in all_val:
        # collecting common items
        c = counter(val)
        # calculate the minimum occurrence in the intersection
        inner_dict = {}
        for inner_val in set(c).intersection(set(p)):
            inner_dict[inner_val] = min(p[inner_val], c[inner_val])
        p = inner_dict
    # >>>p
    # {1: 2, 7: 1}
    # Sort by keys of counter
    sorted_items = sorted(p.items(), key=lambda x:x[0]) # [(1, 2), (7, 1)]
    result=[i[0] for i in sorted_items for _ in range(i[1])] # [1, 1, 7]
    return result

Here are some sample runs:

>>> data = [[1,3,5,7],[1,1,3,5,7],[1,4,7,9]]
>>> my_func(data=data)
[1, 7]
>>> data = [[1,1,3,5,7],[1,1,3,5,7],[1,1,4,7,9]]
>>> my_func(data=data)
[1, 1, 7]
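For comparison, the collections.Counter route mentioned at the start of this answer might look roughly like the sketch below. The function name counter_intersection is made up for this illustration; the & operator on two Counter objects keeps the minimum count of each key, which is the "minimum occurrence" rule used above.

```python
from collections import Counter
from functools import reduce


def counter_intersection(data):
    """Multiset intersection: keep each value min(count) times across all lists."""
    if not data:
        return []
    # Counter & Counter keeps only keys present in both, with the smaller count.
    common = reduce(lambda a, b: a & b, map(Counter, data))
    return sorted(common.elements())


print(counter_intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))        # [1, 7]
print(counter_intersection([[1, 1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 1, 4, 7, 9]]))  # [1, 1, 7]
```

Note that this relies on hashing and does not exploit the sort order, so it sidesteps the "no built-ins" constraint from the question rather than meeting it.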
Answered By: Vaibhav Hiwase

You can do the following using the functions heapq.merge, chain.from_iterable and groupby:

from heapq import merge
from itertools import groupby, chain

ls = [[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]


def index_groups(lst):
    """[1, 1, 3, 5, 7] -> [(1, 0), (1, 1), (3, 0), (5, 0), (7, 0)]"""
    return chain.from_iterable(((e, i) for i, e in enumerate(group)) for k, group in groupby(lst))


iterables = (index_groups(li) for li in ls)
flat = merge(*iterables)
res = [k for (k, _), g in groupby(flat) if sum(1 for _ in g) == len(ls)]
print(res)

Output

[1, 7]

The idea is to give an extra value (using enumerate) to differentiate between equal values within the same list (see the function index_groups).

The complexity of this algorithm is O(n log k), where n is the sum of the lengths of the lists in the input and k is the number of lists, because heapq.merge keeps a heap of the k inputs.

Note that the output for (an extra 1 in each list):

ls = [[1, 1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 1, 4, 7, 9]]

is:

[1, 1, 7]
Answered By: Dani Mesejo

Here is the single-pass counting algorithm, a simplified version of what others have suggested.

import itertools

def intersection(iterables):
    target, count = None, 0
    for it in itertools.cycle(map(iter, iterables)):
        for value in it:
            if count == 0 or value > target:
                target, count = value, 1
                break
            if value == target:
                count += 1
                break
        else:  # exhausted iterator
            return
        if count >= len(iterables):
            yield target
            count = 0

Binary and exponential search haven’t come up yet. They’re easily recreated even with the "no builtins" constraint.

In practice, that would be much faster, and sub-linear. In the worst case – where the intersection isn’t shrinking – the naive approach would repeat work. But there’s a solution for that: integrate the binary search while splitting the arrays in half.

import bisect
import itertools

def intersection(seqs):
    seq = min(seqs, key=len)
    if not seq:
        return
    pivot = seq[len(seq) // 2]
    lows, counts, highs = [], [], []
    for seq in seqs:
        start = bisect.bisect_left(seq, pivot)
        stop = bisect.bisect_right(seq, pivot, start)
        lows.append(seq[:start])
        counts.append(stop - start)
        highs.append(seq[stop:])
    yield from intersection(lows)
    yield from itertools.repeat(pivot, min(counts))
    yield from intersection(highs)

Both handle duplicates. Both guarantee O(N) worst-case time (counting slicing as atomic). The latter will approach O(min_size) speed; by always splitting the smallest in half it essentially can’t suffer from the bad luck of uneven splits.
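The exponential search mentioned above is not spelled out in this answer, so here is a minimal sketch of how it might be recreated without bisect. The names gallop and galloping_intersection are made up for this illustration, and it assumes the inputs are indexable sorted sequences (lists, tuples, strings) rather than one-shot iterators.

```python
def gallop(seq, target, lo=0):
    """Return the first index >= lo whose value is >= target (len(seq) if none)."""
    # Exponential phase: double the step until we reach or pass the target.
    step = 1
    hi = lo
    while hi < len(seq) and seq[hi] < target:
        lo = hi + 1
        hi += step
        step *= 2
    hi = min(hi, len(seq))
    # Binary phase: ordinary lower-bound search within [lo, hi).
    while lo < hi:
        mid = (lo + hi) // 2
        if seq[mid] < target:
            lo = mid + 1
        else:
            hi = mid
    return lo


def galloping_intersection(seqs):
    """Intersection (keeping duplicates) of sorted sequences, skipping ahead by galloping."""
    if not seqs:
        return []
    positions = [0] * len(seqs)
    result = []
    target = None
    agree = 0  # how many consecutive sequences currently sit on `target`
    i = 0
    while True:
        seq = seqs[i]
        if target is not None:
            positions[i] = gallop(seq, target, positions[i])
        if positions[i] == len(seq):
            return result  # one sequence is exhausted, nothing more can match
        value = seq[positions[i]]
        if target is None or value != target:
            target, agree = value, 1  # larger value found: it is the new candidate
        else:
            agree += 1
        if agree == len(seqs):
            result.append(target)
            positions = [p + 1 for p in positions]  # consume one copy everywhere
            target, agree = None, 0
        i = (i + 1) % len(seqs)


print(galloping_intersection([[1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 4, 7, 9]]))  # [1, 7]
```

The galloping step costs O(log d) to skip d elements, which is where the sub-linear behaviour comes from when one list is much longer than the others.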

Answered By: A. Coady

I couldn’t help but notice that this seems to be a variation on the Welfare Crook problem; see David Gries’s book, The Science of Programming. Edsger Dijkstra also wrote an EWD about this; see Ascending Functions and the Welfare Crook.

The Welfare Crook

Suppose we have three long magnetic tapes, each containing a list of names in alphabetical order:

  • all people working for IBM Yorktown
  • students at Columbia University
  • people on welfare in New York City

Practically speaking, all three lists are endless, so no upper bounds are given. It is known that at least one person is on all three lists. Write a program to locate the first such person.

Our intersection of the ordered lists problem is a generalization of the Welfare Crook problem.

Here’s a (rather primitive?) Python solution to the Welfare Crook problem:

def find_welfare_crook(f, g, h, i, j, k):
    """f, g, and h are "ascending functions," i.e.,
i <= j implies f[i] <= f[j] or, equivalently,
f[i] < f[j] implies i < j, and the same goes for g and h.
i, j, k define where to start the search in each list.
"""
    # This is an implementation of a solution to the Welfare Crook
    # problem presented in David Gries's book, The Science of Programming.
    # The surprising and beautiful thing is that the guard predicates are
    # so few and so simple.
    while True:
        if f[i] < g[j]:
            i += 1
        elif g[j] < h[k]:
            j += 1
        elif h[k] < f[i]:
            k += 1
        else:
            break
    return (i,j,k)
    # The other remarkable thing is how the negation of the guard
    # predicates works out to be:  f[i] == g[j] and g[j] == h[k].

Generalization to Intersection of K Lists

This generalizes to K lists, and here’s what I devised; I don’t know how Pythonic this is, but it is pretty compact:

def findIntersectionLofL(lofl):
    """Generalized findIntersection function which operates on a "list of lists." """
    K = len(lofl)
    indices = [0 for i in range(K)]
    result = []
    #
    try:
        while True:
            # idea is to maintain the indices via a construct like the following:
            allEqual = True
            for i in range(K):
                if lofl[i][indices[i]] < lofl[(i+1)%K][indices[(i+1)%K]] :
                    indices[i] += 1
                    allEqual = False
            # When the above iteration finishes, if all of the list
            # items indexed by the indices are equal, then another
            # item common to all of the lists must be added to the result.
            if allEqual :
                result.append(lofl[0][indices[0]])
                while lofl[0][indices[0]] == lofl[1][indices[1]]:
                    indices[0] += 1
    except IndexError as e:
        # Eventually, the foregoing iteration will advance one of the
        # indices past the end of one of the lists, and when that happens
        # an IndexError exception will be raised.  This means the algorithm
        # is finished.
        return result

This solution does not keep repeated items. Changing the program to include all of the repeated items by changing what the program does in the conditional at the end of the "while True" loop is an exercise left to the reader.
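For readers who want to compare notes on that exercise, one possible approach is sketched below. It is my own guess at a solution rather than the author's: intersection_with_repeats is a made-up name, and the change is to consume exactly one copy from every list after each match instead of skipping past all copies.

```python
def intersection_with_repeats(lofl):
    """Welfare-Crook-style scan that reports a shared value once per shared copy."""
    K = len(lofl)
    if K == 0:
        return []
    indices = [0] * K
    result = []
    try:
        while True:
            # Advance any list whose current item is below its neighbour's,
            # until no list can move, i.e. all current items are equal.
            moved = True
            while moved:
                moved = False
                for i in range(K):
                    j = (i + 1) % K
                    if lofl[i][indices[i]] < lofl[j][indices[j]]:
                        indices[i] += 1
                        moved = True
            result.append(lofl[0][indices[0]])
            # Consume exactly one copy from every list, so a value that occurs
            # m times in every list is reported m times.
            for i in range(K):
                indices[i] += 1
    except IndexError:
        # An index ran past the end of its list: nothing more can match.
        return result


print(intersection_with_repeats([[1, 1, 3, 5, 7], [1, 1, 3, 5, 7], [1, 1, 4, 7, 9]]))  # [1, 1, 7]
```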

Improved Performance

Comments from @greybeard prompted the refinements shown below, namely the pre-computation of the "array index moduli" (the "(i+1)%K" expressions). Further investigation also brought about changes to the inner iteration’s structure, to remove further overhead:

def findIntersectionLofLunRolled(lofl):
    """Generalized findIntersection function which operates on a "list of lists."
Accepts a list-of-lists, lofl.  Each of the lists must be ordered.
Returns the list of each element which appears in all of the lists at least once.
"""
    K = len(lofl)
    indices = [0] * K
    result = []
    lt = [ (i, (i+1) % K) for i in range(K) ] # avoids evaluation of index exprs inside the loop
    #
    try:
        while True:
            allUnEqual = True
            while allUnEqual:
                allUnEqual = False
                for i,j in lt:
                    if lofl[i][indices[i]] < lofl[j][indices[j]]:
                        indices[i] += 1
                        allUnEqual = True
            # Now all of the lofl[i][indices[i]], for all i, are the same value.
            # Store that value in the result, and then advance all of the indices
            # past that common value:
            v = lofl[0][indices[0]]
            result.append(v)
            for i,j in lt:
                while lofl[i][indices[i]] == v:
                    indices[i] += 1
    except IndexError as e:
        # Eventually, the foregoing iteration will advance one of the
        # indices past the end of one of the lists, and when that happens
        # an IndexError exception will be raised.  This means the algorithm
        # is finished.
        return result
Answered By: Rich Stewart