Kth Smallest Element in multiple sorted arrays
Question:
Let’s say we have two arrays:
array1 = [2,3,6,7,9]
array2 = [1,4,8,10]
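I understand how to find the kth element of two sorted arrays in O(log(min(m, n))) time, where m is the length of array1 and n is the length of array2, as follows:
```python
# Sentinels standing in for minus/plus infinity at the array boundaries.
MIN_VALUE = float('-inf')
MAX_VALUE = float('inf')

def kthelement(arr1, arr2, m, n, k):
    # Always binary search the shorter array (arr1, of length m).
    if m > n:
        return kthelement(arr2, arr1, n, m, k)
    # cut1 elements are taken from arr1 and cut2 = k - cut1 from arr2.
    low = max(0, k - n)
    high = min(k, m)
    while low <= high:
        cut1 = (low + high) >> 1
        cut2 = k - cut1
        l1 = MIN_VALUE if cut1 == 0 else arr1[cut1 - 1]
        l2 = MIN_VALUE if cut2 == 0 else arr2[cut2 - 1]
        r1 = MAX_VALUE if cut1 == m else arr1[cut1]
        r2 = MAX_VALUE if cut2 == n else arr2[cut2]
        if l1 <= r2 and l2 <= r1:
            # Valid partition: the kth element is the larger left-side value.
            return max(l1, l2)
        elif l1 > r2:
            high = cut1 - 1
        else:
            low = cut1 + 1
```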
But I couldn’t figure out how to extend this to the case of multiple sorted arrays. For example, given 3 arrays, I want to find the kth element of the final merged sorted array.
array1 = [2,3,6,7,9]
array2 = [1,4,8,10]
array3 = [2,3,5,7]
Is it possible to achieve this in O(log(min(m, n))) time, as in the two-array case?
Answers:
The general solution is to use a min-heap. If you have n sorted arrays and you want the kth smallest number, the solution is O(k log n), plus O(n) to build the initial heap.
The idea is that you insert the first number from each array into the min-heap. With each number you also store where it came from: the array and the number's position in that array.
You then remove the smallest value from the heap and insert the next number from the array that value came from. The kth value removed this way is the kth smallest number.
See https://www.geeksforgeeks.org/find-m-th-smallest-value-in-k-sorted-arrays/ for the general idea.
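A minimal sketch of the heap approach, using Python's standard `heapq` module. The function name `kth_smallest_heap` is hypothetical, and returning `None` for an out-of-range k is an assumption (the answer does not say what should happen). Each heap entry is a tuple of the value, the array index and the position within that array, so the successor from the same array can be pushed after each pop.

```python
import heapq
from typing import List, Optional

def kth_smallest_heap(arrays: List[List[int]], k: int) -> Optional[int]:
    """Return the kth smallest (1-based) element across sorted arrays, or None."""
    if k < 1:
        return None
    # Seed the heap with the first element of every non-empty array.
    # Each entry is (value, array index, position within that array).
    heap = [(a[0], i, 0) for i, a in enumerate(arrays) if a]
    heapq.heapify(heap)
    # Pop the k - 1 smallest elements; the heap top is then the kth.
    for _ in range(k - 1):
        if not heap:
            return None
        _, i, j = heapq.heappop(heap)
        # Replace the popped element with its successor from the same array.
        if j + 1 < len(arrays[i]):
            heapq.heappush(heap, (arrays[i][j + 1], i, j + 1))
    return heap[0][0] if heap else None

arrays = [[2, 3, 6, 7, 9], [1, 4, 8, 10], [2, 3, 5, 7]]
print(kth_smallest_heap(arrays, 6))
```

Building the heap costs O(n), and each of the k - 1 pop/push pairs costs O(log n).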
If k is very large, we can instead binary search on the answer. With N the range of the element values, n the number of arrays and L the length of the longest array, this takes O(n log L log N) time: there are O(log N) probes, and each probe does one binary search per array.
The key step is checking, for a given integer x, whether the correct answer is <= x. We enumerate the arrays, binary search each one to count its elements that are less than or equal to x, sum these counts, and compare the total with k: the answer is <= x exactly when the total is at least k.
```python
from typing import List
import bisect

def query_k_min(vecs: List[List[int]], k: int) -> int:
    # we assume each number >=1 and <=10^9
    l, r = 0, 10**9
    while r - l > 1:
        m = (l+r)>>1
        tot = 0
        for vec in vecs:
            tot += bisect.bisect_right(vec, m)
        if tot >= k: r = m
        else: l = m
    return r

a = [[2,3,6,7,9],[1,4,8,10],[2,3,5,7]]
for x in range(1,14):
    print(query_k_min(a,x))
```
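`query_k_min` hard-codes the search range [1, 10^9]. As a sketch, assuming the elements are integers (possibly negative), the bounds can instead be taken from the arrays themselves: one below the smallest first element, and the largest last element. The name `kth_by_value_search` is hypothetical, and returning `None` for an out-of-range k is an assumption.

```python
import bisect
from typing import List, Optional

def kth_by_value_search(arrays: List[List[int]], k: int) -> Optional[int]:
    """Binary search on the answer, with bounds taken from the data.

    Assumes integer elements; negative values are fine.
    """
    total = sum(len(a) for a in arrays)
    if k < 1 or k > total:
        return None
    # count(<= lo) == 0 < k and count(<= hi) == total >= k.
    lo = min(a[0] for a in arrays if a) - 1
    hi = max(a[-1] for a in arrays if a)
    # Invariant: count(<= lo) < k <= count(<= hi).
    while hi - lo > 1:
        mid = (lo + hi) // 2
        count = sum(bisect.bisect_right(a, mid) for a in arrays)
        if count >= k:
            hi = mid
        else:
            lo = mid
    return hi
```

The returned `hi` is the smallest integer x with at least k elements <= x, which is always an actual element value.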
The following looks complicated, but if M is the sum of log(len(list) + 2) over all the lists, then the average case is O(M) and the worst case is O(M^2). (The +2 is there because even an empty array requires some work, so every term should be at least log 2.) The worst case is very unlikely. The performance is independent of k.
The idea is the same as in Quickselect: we pick pivots and split the data around each pivot. But we do not look at every element; we only figure out which chunk of each array still under consideration lies before, after, or at the pivot. The average case holds because every time we look at an array, with positive probability we get rid of half of what remains. The worst case arises because every time we look at the array we took the pivot from, we get rid of half of that array, but we may have to binary search every other array only to find that we got rid of nothing else.
```python
from collections import deque

def kth_of_sorted(k, arrays):
    # Counters: elements already known to rank below / above the kth
    # element, and the total number of elements.
    known_low = 0
    known_high = 0
    total_size = 0

    # in_flight will be a double-ended queue of (array, low, high), where
    # array[low..high] (inclusive) is the part of an input array that
    # might still contain the kth element.
    in_flight = deque()
    for a in arrays:
        if 0 < len(a):
            total_size += len(a)
            in_flight.append((a, 0, len(a)-1))

    # Sanity check.
    if k < 1 or total_size < k:
        return None

    while 0 < len(in_flight):
        start_a, start_low, start_high = in_flight.popleft()
        start_mid = (start_low + start_high) // 2
        pivot = start_a[start_mid]

        # If the pivot is placed, how many elements are known to go below / above it?
        maybe_low = start_mid - start_low
        maybe_high = start_high - start_mid

        # Arrays taken from in_flight as:
        #
        #     (array, low, high, orig_low, orig_high)
        #
        # We are binary searching in these to figure out where the pivot
        # is going to go. Then we copy back to in_flight.
        to_process = deque()

        # Arrays taken from in_flight as:
        #
        #     (array, orig_low, mid, orig_high)
        #
        # where array[mid] matches the pivot.
        is_match = deque()

        # And we already know one array with the pivot!
        is_match.append((start_a, start_low, start_mid, start_high))

        # Arrays taken from in_flight which we know do not have the pivot:
        #
        #     (array, low, high, orig_low, orig_high)
        no_pivot = deque()

        # First pass: one binary-search step in every other array.
        while 0 < len(in_flight):
            a, low, high = in_flight.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, low, high))
                else:
                    no_pivot.append((a, mid+1, high, low, high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, low, high))
                else:
                    no_pivot.append((a, low, mid-1, low, high))
            else:
                # mid is at the pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, low, mid, high))

        # We do not yet know where the pivot goes relative to the kth element:
        # 'right' means the pivot ranks above it, 'left' means below it.
        pivot_pos = None
        if k <= known_low + maybe_low:
            pivot_pos = 'right'
        elif total_size - known_high - maybe_high < k:
            pivot_pos = 'left'
        elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
            return pivot # WE FOUND IT!

        while pivot_pos is None:
            # This is very similar to how we processed in_flight.
            a, low, high, orig_low, orig_high = to_process.popleft()
            mid = (low + high) // 2
            if a[mid] < pivot:
                # all of low, low+1, ..., mid are below the pivot
                maybe_low += mid + 1 - low
                if mid < high:
                    to_process.append((a, mid+1, high, orig_low, orig_high))
                else:
                    no_pivot.append((a, mid+1, high, orig_low, orig_high))
            elif pivot < a[mid]:
                # all of mid, mid+1, ..., high are above the pivot
                maybe_high += high + 1 - mid
                if low < mid:
                    to_process.append((a, low, mid-1, orig_low, orig_high))
                else:
                    no_pivot.append((a, low, mid-1, orig_low, orig_high))
            else:
                # mid is at the pivot
                maybe_low += mid - low
                maybe_high += high - mid
                is_match.append((a, orig_low, mid, orig_high))

            if k <= known_low + maybe_low:
                pivot_pos = 'right'
            elif total_size - known_high - maybe_high < k:
                pivot_pos = 'left'
            elif k <= known_low + maybe_low + len(is_match) and total_size < k + known_high + maybe_high + len(is_match):
                return pivot # WE FOUND IT!

        # And now place the pivot in the right position.
        if pivot_pos == 'right':
            known_high += maybe_high + len(is_match)
            # And put back the left side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if orig_low <= high:
                        in_flight.append((a, orig_low, high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if low < mid:
                    in_flight.append((a, low, mid-1))
        else:
            known_low += maybe_low + len(is_match)
            # And put back the right side of each nonemptied array.
            for q in (to_process, no_pivot):
                while 0 < len(q):
                    a, low, high, orig_low, orig_high = q.popleft()
                    if low <= orig_high:
                        in_flight.append((a, low, orig_high))
            while 0 < len(is_match):
                a, low, mid, high = is_match.popleft()
                if mid < high:
                    in_flight.append((a, mid+1, high))

list1 = [2,3,6,7,9]
list2 = [1,4,8,10]
list3 = [2,3,5,7]
print(list1, list2, list3)
# k runs from 1 to the total number of elements, inclusive.
for i in range(1, len(list1) + len(list2) + len(list3) + 1):
    print(i, kth_of_sorted(i, [list1, list2, list3]))
```
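For comparison, here is a much simpler, non-incremental variant of the same Quickselect idea. It is a sketch, not the answer's code, and the name `kth_quickselect_simple` is hypothetical. Each round it picks a random pivot from the still-active elements and does a full `bisect_left`/`bisect_right` in every active window, so a round costs O(n log L) instead of the incremental binary searching used by `kth_of_sorted`. The pivot is the answer exactly when fewer than k active elements are smaller than it and at least k are less than or equal to it.

```python
import bisect
import random
from typing import List, Optional

def kth_quickselect_simple(k: int, arrays: List[List[int]]) -> Optional[int]:
    """Return the kth smallest (1-based) element across sorted arrays, or None."""
    # Active window [lo[i], hi[i]) of each array; k is relative to the windows.
    lo = [0] * len(arrays)
    hi = [len(a) for a in arrays]
    if k < 1 or k > sum(hi):
        return None
    while True:
        # Pick a uniformly random pivot among the active elements.
        r = random.randrange(sum(h - l for l, h in zip(lo, hi)))
        for i, a in enumerate(arrays):
            width = hi[i] - lo[i]
            if r < width:
                pivot = a[lo[i] + r]
                break
            r -= width
        # Per window: end of the elements < pivot and of the elements <= pivot.
        less = [bisect.bisect_left(a, pivot, lo[i], hi[i]) for i, a in enumerate(arrays)]
        less_eq = [bisect.bisect_right(a, pivot, lo[i], hi[i]) for i, a in enumerate(arrays)]
        n_less = sum(x - l for x, l in zip(less, lo))
        n_less_eq = sum(x - l for x, l in zip(less_eq, lo))
        if k <= n_less:
            hi = less          # the kth is strictly below the pivot
        elif k <= n_less_eq:
            return pivot       # the kth equals the pivot
        else:
            k -= n_less_eq     # discard everything <= pivot
            lo = less_eq
```

Every round either returns or removes at least the pivot itself from the active windows, so the loop terminates.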