Kth Smallest Element in multiple sorted arrays

Question:

Let’s say we have two arrays:

array1 = [2,3,6,7,9]

array2 = [1,4,8,10]

I understand how to find the kth element of two sorted arrays in O(log(min(m, n))) time, where m is the length of array1 and n is the length of array2, as follows:

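def kthelement(arr1, arr2, m, n, k):
    # Binary search over the shorter array, so make arr1 the shorter one.
    if m > n:
        return kthelement(arr2, arr1, n, m, k)

    MIN_VALUE, MAX_VALUE = float('-inf'), float('inf')

    # cut1 elements are taken from arr1 and cut2 = k - cut1 from arr2.
    low = max(0, k - n)
    high = min(k, m)

    while low <= high:
        cut1 = (low + high) >> 1
        cut2 = k - cut1
        l1 = MIN_VALUE if cut1 == 0 else arr1[cut1 - 1]
        l2 = MIN_VALUE if cut2 == 0 else arr2[cut2 - 1]
        r1 = MAX_VALUE if cut1 == m else arr1[cut1]
        r2 = MAX_VALUE if cut2 == n else arr2[cut2]

        if l1 <= r2 and l2 <= r1:
            # Valid partition: the kth element is the larger left neighbour.
            return max(l1, l2)
        elif l1 > r2:
            high = cut1 - 1
        else:
            low = cut1 + 1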

But I couldn’t figure out how to extend this to the case of multiple sorted arrays. For example, given 3 arrays, I want to find the kth element of the final sorted array.

array1 = [2,3,6,7,9]

array2 = [1,4,8,10]

array3 = [2,3,5,7]

Is it possible to achieve it in log(min(m,n)) as in the two array case?

Asked By: M.Soyturk


Answers:

The general solution is to use a min-heap. If you have n sorted arrays and you want the kth smallest number, then the solution is O(k log n).

The idea is that you insert the first number from each array into the min-heap. When inserting into the heap, you insert a tuple that contains the number, and the array that it came from.

You then remove the smallest value from the heap and add the next number from the array that value came from. You do this k times to get the kth smallest number.

See https://www.geeksforgeeks.org/find-m-th-smallest-value-in-k-sorted-arrays/ for the general idea.
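
The steps above can be sketched with Python's heapq module. This is a minimal sketch rather than code from the linked article: the function name kth_smallest_heap is made up for illustration, and each heap entry stores the value, the index of the array it came from, and its position in that array.

```python
import heapq

def kth_smallest_heap(arrays, k):
    """Return the kth smallest value (1-based) across sorted arrays, or None."""
    # Seed the heap with the first element of every non-empty array.
    # Each entry is (value, array index, position within that array).
    heap = [(a[0], i, 0) for i, a in enumerate(arrays) if a]
    heapq.heapify(heap)

    value = None
    for _ in range(k):
        if not heap:
            return None  # k is larger than the total number of elements
        value, i, j = heapq.heappop(heap)
        # Replace the popped value with its successor from the same array.
        if j + 1 < len(arrays[i]):
            heapq.heappush(heap, (arrays[i][j + 1], i, j + 1))
    return value

arrays = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print(kth_smallest_heap(arrays, 5))
```

The heap never holds more than one entry per array, which is where the log n factor in O(k log n) comes from.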

Answered By: Jim Mischel

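If k is very large, we can binary search on the answer instead. Each of the O(log N) search steps does one binary search per array, so the time complexity is O(n * log L * log N), where N is the range of the element values, n is the number of arrays, and L is the length of the longest array.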

What we need is a way to check whether a given integer x is at least the correct answer. For each array, binary search to count the elements less than or equal to x, add up the counts, and compare the total with k. The smallest x for which the total is at least k is the answer.

from typing import List
import bisect

def query_k_min(vecs: List[List[int]], k: int) -> int:
    # we assume each number >=1 and <=10^9
    l, r = 0, 10**9
    while r - l > 1:
        m = (l+r)>>1
        tot = 0
        for vec in vecs:
            tot += bisect.bisect_right(vec, m)
        if tot >= k: r = m
        else: l = m
    return r

a = [[2,3,6,7,9],[1,4,8,10],[2,3,5,7]]
for x in range(1,14):
    print(query_k_min(a,x))
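
If the value range is not known in advance, one option is to take the search bounds from the data itself. The following is a sketch under that assumption, not part of the code above: the name kth_by_value_search is hypothetical, and it assumes integer elements, since it bisects on integer midpoints.

```python
import bisect

def kth_by_value_search(arrays, k):
    """kth smallest (1-based) across sorted integer arrays, or None if k is out of range."""
    nonempty = [a for a in arrays if a]
    total = sum(len(a) for a in nonempty)
    if k < 1 or k > total:
        return None

    # The answer lies between the smallest first element and the largest last element.
    lo = min(a[0] for a in nonempty)
    hi = max(a[-1] for a in nonempty)

    # Find the smallest x such that at least k elements are <= x.
    while lo < hi:
        mid = (lo + hi) // 2
        count = sum(bisect.bisect_right(a, mid) for a in nonempty)
        if count >= k:
            hi = mid
        else:
            lo = mid + 1
    return lo

a = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print([kth_by_value_search(a, x) for x in range(1, 14)])
```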


Answered By: rdc

The following looks complicated, but if M is the sum of log(len(list) + 2) over all the lists, then the average case is O(M) and the worst case is O(M^2). (The +2 is there because even an empty array requires some work, so each logarithm is taken of at least 2.) The worst case is very unlikely.
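
To make M concrete, here is a small calculation for the arrays in the question. It assumes base-2 logarithms, which the description above does not specify, and the helper name total_log_size is made up for illustration.

```python
import math

def total_log_size(arrays):
    # Sum of log2(len + 2) over all arrays; the +2 keeps every term at least 1,
    # so an empty array still contributes one unit of work.
    return sum(math.log2(len(a) + 2) for a in arrays)

arrays = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print(round(total_log_size(arrays), 3))
```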

The performance is independent of k.

The idea is the same as Quickselect: we pick pivots and split the data around each pivot. But we do not look at every element; we only figure out which chunk of each array still under consideration lies before, after, or at the pivot. The average case holds because every time we look at an array, with positive probability we get rid of half of what remains. The worst case arises because every time we look at the array the pivot came from, we get rid of half of that array, but may have to binary search every other array only to find that nothing else was eliminated.
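
Before the full implementation, a simplified sketch of the pivoting idea may help. It is not the algorithm described above. It picks a random pivot from the remaining candidates and runs a complete binary search in every array for each pivot, so it does not reach the stated bounds. The name kth_by_pivoting and the fixed random seed are assumptions made for illustration.

```python
import bisect
import random

def kth_by_pivoting(k, arrays, seed=0):
    """Simplified Quickselect over sorted arrays; returns None if k is out of range."""
    total = sum(len(a) for a in arrays)
    if k < 1 or k > total:
        return None

    # windows[i] = [lo, hi): the slice of arrays[i] that may still hold the answer.
    windows = [[0, len(a)] for a in arrays]
    below = 0  # elements already known to be smaller than the answer
    rng = random.Random(seed)

    while True:
        # Pick a pivot uniformly among the remaining candidates.
        pick = rng.randrange(sum(hi - lo for lo, hi in windows))
        for a, (lo, hi) in zip(arrays, windows):
            if pick < hi - lo:
                pivot = a[lo + pick]
                break
            pick -= hi - lo

        # Locate the pivot inside every window without scanning elements.
        lefts = [bisect.bisect_left(a, pivot, lo, hi)
                 for a, (lo, hi) in zip(arrays, windows)]
        rights = [bisect.bisect_right(a, pivot, lo, hi)
                  for a, (lo, hi) in zip(arrays, windows)]
        less = sum(l - w[0] for l, w in zip(lefts, windows))
        equal = sum(r - l for l, r in zip(lefts, rights))

        if k <= below + less:
            # The answer is smaller than the pivot: keep only the left chunks.
            for w, l in zip(windows, lefts):
                w[1] = l
        elif k <= below + less + equal:
            return pivot
        else:
            # The answer is larger than the pivot: keep only the right chunks.
            below += less + equal
            for w, r in zip(windows, rights):
                w[0] = r

arrays = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print([kth_by_pivoting(k, arrays) for k in range(1, 14)])
```

The implementation below avoids the full binary search per array by interleaving single bisection steps across arrays and stopping as soon as the pivot's side is known.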

from collections import deque

def kth_of_sorted (k, arrays):
    # Initialize some running totals.
    known_low = 0
    known_high = 0
    total_size = 0

    # in_flight will be a double-ended queue of
    # (array, low, high)
    # Where:
    #    array is an input array
    #    low is the lower bound on where kth might be
    #    high is the upper bound on where kth might be
    in_flight = deque()

    for a in arrays:
        if 0 < len(a):
            total_size += len(a)
            in_flight.append((a, 0, len(a)-1))

    # Sanity check.
    if k < 1 or total_size < k:
        return None

    while 0 < len(in_flight):
        start_a, start_low, start_high = in_flight.popleft()
        start_mid = (start_low + start_high) // 2
        pivot = start_a[start_mid]

        # If pivot is placed, how many are known?
        maybe_low = start_mid - start_low
        maybe_high = start_high - start_mid

        # This will be arrays taken from in_flight with:
        #
        #    (array, low, high, orig_low, orig_high)
        #
        # We are binary searching in these to figure out where the pivot
        # is going to go. Then we copy back to in_flight.
        to_process = deque()

        # This will be arrays taken from in_flight with:
        #
        #    (array, orig_low, mid, orig_high)
        #
        # where at mid we match the pivot.
        is_match = deque()
        # And we know an array with a pivot!
        is_match.append((start_a, start_low, start_mid, start_high))

        # This will be arrays taken from in_flight which we know do not have the pivot:
        #
        #    (array, low, high, orig_low, orig_high)
        #
        no_pivot = deque()

        while 0 < len(in_flight):
            a, low, high = in_flight.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, low, high))
                else:
                    no_pivot.append((a, mid+1, high, low, high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot.
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, low, high))
                else:
                    no_pivot.append((a, low, mid-1, low, high))
            else:
                # mid is at pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, low, mid, high))

        # We do not yet know where the pivot_pos is.
        pivot_pos = None
        if k <= known_low + maybe_low:
            pivot_pos = 'right'
        elif total_size - known_high - maybe_high < k:
            pivot_pos = 'left'
        elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
            return pivot # WE FOUND IT!

        while pivot_pos is None:
            # This is very similar to how we processed in_flight.
            a, low, high, orig_low, orig_high = to_process.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, orig_low, orig_high))
                else:
                    no_pivot.append((a, mid+1, high, orig_low, orig_high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot.
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, orig_low, orig_high))
                else:
                    no_pivot.append((a, low, mid-1, orig_low, orig_high))
            else:
                # mid is at pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, orig_low, mid, orig_high))

            if k <= known_low + maybe_low:
                pivot_pos = 'right'
            elif total_size - known_high - maybe_high < k:
                pivot_pos = 'left'
            elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
                return pivot # WE FOUND IT!

        # And now place the pivot in the right position.
        if pivot_pos == 'right':
            known_high += maybe_high + len(is_match)
            # And put back the left side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if orig_low <= high:
                        in_flight.append((a, orig_low, high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if low < mid:
                    in_flight.append((a, low, mid-1))
        else:
            known_low += maybe_low + len(is_match)
            # And put back the right side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if low <= orig_high:
                        in_flight.append((a, low, orig_high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if mid < high:
                    in_flight.append((a, mid+1, high))

list1 = [2,3,6,7,9]
list2 = [1,4,8,10]
list3 = [2,3,5,7]
print(list1, list2, list3)
for i in range(1, len(list1) + len(list2) + len(list3)):
    print(i, kth_of_sorted(i,[list1, list2, list3]))
Answered By: btilly