What is a quick way to count the number of pairs in a list where a XOR b is greater than a AND b?

Question:

I have an array of numbers, and I want to count all pairs for which the XOR of the pair is greater than the AND of the pair.

Example:

4,3,5,2

possible pairs are:

(4,3) -> xor=7, and = 0
(4,5) -> xor=1, and = 4
(4,2) -> xor=6, and = 0
(3,5) -> xor=6, and = 1
(3,2) -> xor=1, and = 2
(5,2) -> xor=7, and = 0

Valid pairs for which xor > and are (4,3), (4,2), (3,5), (5,2) so result is 4.
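The table above can be reproduced with a short brute-force sketch. `itertools.combinations` yields each unordered pair once, in the same order as the list above; `valid_pairs` is a hypothetical helper name used only for illustration.

```python
from itertools import combinations

def valid_pairs(array):
    """Return the pairs (a, b), a before b, for which a ^ b > a & b."""
    result = []
    for a, b in combinations(array, 2):
        # Python evaluates ^ and & before >, so this is (a ^ b) > (a & b)
        if a ^ b > a & b:
            result.append((a, b))
    return result

# Print the xor/and table for the example input
for a, b in combinations([4, 3, 5, 2], 2):
    print((a, b), "xor =", a ^ b, "and =", a & b)

print(valid_pairs([4, 3, 5, 2]))  # [(4, 3), (4, 2), (3, 5), (5, 2)]
```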

This is my program:

def solve(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans +=1
    return ans

The time complexity is O(n^2), but my array has between 1 and 10^5 elements, and each element is between 1 and 2^30. How can I reduce the time complexity of this program?

Asked By: learner


Answers:

Say a and b are non-negative integers. Then a^b > a&b if and only if a and b have different highest set bits.
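The claim can be checked exhaustively for small values. This sketch assumes non-negative integers and uses `int.bit_length()` as "the position of the highest set bit"; it returns 0 for 0, so 0 simply forms its own bucket. The helper names are illustrative only.

```python
def xor_beats_and(a, b):
    return (a ^ b) > (a & b)

def different_top_bit(a, b):
    # int.bit_length() is the 1-based position of the highest set bit (0 for 0)
    return a.bit_length() != b.bit_length()

LIMIT = 256  # check every ordered pair of integers in [0, 256)
mismatches = [
    (a, b)
    for a in range(LIMIT)
    for b in range(LIMIT)
    if xor_beats_and(a, b) != different_top_bit(a, b)
]
print(len(mismatches))  # 0
```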

Solution: use a counting map where the keys are the highest set bits. Populate this in linear time.

Then, process the keys. Say there are n integers in total, and some key has r integers (with the same highest set bit). Then that key adds r * (n-r) to the count of pairs where xor > and. That is, each of the r integers can be paired with each of the (n-r) integers that have a different highest set bit.

This counts every such pair twice (once from each of the two buckets it spans), so divide by 2 at the end.
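Writing r_k for the size of bucket k, so that n = Σ r_k, the count described above can be written out and simplified; it only needs the sum of the squared bucket sizes:

```latex
\text{pairs} = \frac{1}{2}\sum_k r_k\,(n - r_k)
             = \frac{1}{2}\Big(n\sum_k r_k - \sum_k r_k^2\Big)
             = \frac{n^2 - \sum_k r_k^2}{2}
```

For bucket sizes 3, 3, and 2 (n = 8) this gives (64 - 22) / 2 = 21.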


Example:

Say we have 8 integers, 3 of which have the third bit as their highest set bit, 3 the fourth, and 2 the fifth.

So per my algorithm, we have three buckets of sizes 3, 3, and 2, and a solution of [3*(8-3) + 3*(8-3) + 2*(8-2)] / 2 = 42 / 2 = 21.
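The arithmetic can be cross-checked against brute force. The eight values below are one arbitrary, hypothetical choice matching the description (bit_length 3 means the third bit is the highest set bit); any other choice with the same bucket sizes gives the same count.

```python
from collections import Counter
from itertools import combinations

# Hypothetical sample: 3 numbers in [4, 8), 3 in [8, 16), 2 in [16, 32)
nums = [4, 5, 6, 8, 9, 10, 16, 17]

buckets = Counter(x.bit_length() for x in nums)  # bucket sizes 3, 3, 2
n = len(nums)
by_formula = sum(r * (n - r) for r in buckets.values()) // 2
by_brute_force = sum(1 for a, b in combinations(nums, 2) if a ^ b > a & b)
print(by_formula, by_brute_force)  # 21 21
```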

Here’s a more detailed explanation of the approach:

Any two integers within the same bucket give a higher value under AND than under XOR: both have the bucket's bit set and no higher bit set, so AND keeps that bit as 1 while XOR turns it to 0, and all lower bits together are worth less than that bit.

Now take two integers in separate buckets. One of them has a higher highest set bit than the other, so that bit is set in one number but not in the other. It becomes 0 under AND but is preserved under XOR, and no higher bit is set in either number, so XOR gives the result with the higher highest set bit, i.e. the larger result.

The number of pairs where XOR yields a higher result than AND is exactly the number of pairs which are not both in the same bucket.
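Since the valid pairs are exactly the pairs that are not inside one bucket, the same number can also be obtained by subtracting the same-bucket pairs from all n(n-1)/2 pairs. This complement form is shown only as a cross-check of the statement above (it avoids counting any pair twice); `count_by_complement` is a hypothetical name.

```python
from collections import Counter

def count_by_complement(array):
    """Count pairs with a ^ b > a & b as (all pairs) - (same-bucket pairs)."""
    n = len(array)
    sizes = Counter(x.bit_length() for x in array).values()
    all_pairs = n * (n - 1) // 2
    same_bucket_pairs = sum(r * (r - 1) // 2 for r in sizes)
    return all_pairs - same_bucket_pairs

print(count_by_complement([4, 3, 5, 2]))  # 4
```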

Say we have n integers total, and r in some bucket. Each of the r integers in that bucket can be paired with any of the (n-r) integers in the other buckets and not with any of the r-1 integers in the same bucket, contributing r * (n-r) to the count of pairs where XOR yields a higher integer than AND.

However, this counts every pair that contributes to our count exactly twice. E.g., the pair 1001 and 110 is counted once in the bucket whose highest set bit is the 4th bit (from 1001's side) and once more in the bucket whose highest set bit is the 3rd bit (from 110's side). Thus at the end we have to divide by 2.
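The factor of 2 becomes visible when counting ordered pairs: summing r * (n - r) over buckets counts (x, y) and (y, x) separately. This sketch uses the two numbers from the example, 1001 (9) and 110 (6), plus one arbitrary extra value, 101 (5).

```python
from collections import Counter

nums = [0b1001, 0b110, 0b101]  # 9, 6, 5 -> bit_length 4, 3, 3
n = len(nums)
sizes = Counter(x.bit_length() for x in nums)

per_bucket_sum = sum(r * (n - r) for r in sizes.values())
# Ordered pairs (x, y) at distinct positions whose highest set bits differ
ordered = sum(
    1
    for i, x in enumerate(nums)
    for j, y in enumerate(nums)
    if i != j and x.bit_length() != y.bit_length()
)
print(per_bucket_sum, ordered, per_bucket_sum // 2)  # 4 4 2
```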


Further example:

Here are all integers with the third bit as their highest set bit: 4, 5, 6, and 7 (100, 101, 110, and 111 in binary).

Here are all integers with the second bit as their highest set bit: 2 and 3 (10 and 11 in binary).

Take any pair with the same highest set bit; arbitrarily, I'll choose 6 and 7. AND(110, 111) = 110 = 6, while XOR(110, 111) = 001 = 1, so the AND operation produces the higher result. In every such case, XOR turns the highest set bit from 1 into 0 while AND keeps it at 1, so AND always gives the higher result.

Taking pairs from separate buckets, whichever bit is the highest set bit in the pair is set in only one of the two integers (because this is what we're bucketing by), so under AND that bit becomes 0, and under XOR it remains 1. Since XOR leaves the output with a higher highest set bit than AND does, the XOR result is larger.
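To see both cases side by side for the two buckets above, this sketch prints every pair from {2, 3, 4, 5, 6, 7} in binary with its AND and XOR, and whether the pair shares a bucket.

```python
from itertools import combinations

values = [2, 3, 4, 5, 6, 7]  # second-bit bucket {2, 3}, third-bit bucket {4..7}
for a, b in combinations(values, 2):
    same = a.bit_length() == b.bit_length()
    winner = "AND" if (a & b) > (a ^ b) else "XOR"
    print(f"{a:03b} {b:03b}  and={a & b:03b} xor={a ^ b:03b}  "
          f"{'same' if same else 'different'} bucket -> {winner} larger")
```

For distinct non-negative a and b, AND and XOR can never tie (they share no set bits and cannot both be 0), so "larger" is always well defined here.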

Answered By: Dave

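This uses (effectively) the same algorithm as you, so it's still O(n^2), but you can speed up the operation using numpy. Note that it builds n-by-n matrices, so for n = 10^5 each matrix has 10^10 entries and won't fit in typical memory; it is only practical for smaller inputs: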

  • np.bitwise_xor performs the bitwise xor operation on two arrays
  • np.bitwise_and performs the bitwise and operation on two arrays
  • Giving a row-vector and a column-vector to these functions allows numpy to broadcast the result to a square matrix.
  • Comparing the resulting matrices gives a boolean array. We only need the lower triangle of this matrix, but the matrix is symmetric and its diagonal is all False (a ^ a == 0, which is never greater than a & a == a), so we can simply sum the entire array and divide the result by 2.

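A small check of the shapes and of the two facts the last bullet relies on (the comparison matrix is symmetric, and its diagonal is all False for non-negative inputs), using the example from the question:

```python
import numpy as np

nums = np.array([4, 3, 5, 2])
col = nums[:, None]  # column vector, shape (4, 1)
gt = np.bitwise_xor(nums, col) > np.bitwise_and(nums, col)  # shape (4, 4)

print(gt.shape)                  # (4, 4)
print(np.array_equal(gt, gt.T))  # True: x ^ y and x & y are symmetric in x, y
print(gt.diagonal().any())       # False: a ^ a == 0 is never > a & a == a
print(gt.sum() // 2)             # 4
```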
import numpy as np

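def npy(nums):
    nums = np.asarray(nums)  # nums[:, None] needs a numpy array, not a list
    xor_arr = np.bitwise_xor(nums, nums[:, None])  # n x n matrix of x ^ y
    and_arr = np.bitwise_and(nums, nums[:, None])  # n x n matrix of x & y

    # Symmetric matrix with an all-False diagonal, so halve the full count
    return (xor_arr > and_arr).sum() // 2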
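If numpy is already a dependency, the linear-time bucket counting from Dave's answer can also be vectorized instead of building n-by-n matrices. This is a sketch under the question's constraints (non-negative integers well below 2^53, so conversion to float64 is exact): `np.frexp` returns an exponent e with x = m * 2**e and 0.5 <= m < 1, which for such integers equals `int.bit_length()` (and 0 for 0). `count_pairs_bucketed` is a hypothetical name.

```python
import numpy as np

def count_pairs_bucketed(nums):
    """O(n) count of pairs with x ^ y > x & y, grouping by highest set bit."""
    arr = np.asarray(nums, dtype=np.int64)
    # For 0 <= x < 2**53, the frexp exponent equals x.bit_length()
    _, bit_len = np.frexp(arr.astype(np.float64))
    sizes = np.bincount(bit_len).astype(np.int64)  # bucket sizes (zeros allowed)
    n = arr.size
    # sum_k r_k * (n - r_k) == n**2 - sum_k r_k**2; halve for double counting
    return (n * n - int((sizes ** 2).sum())) // 2

print(count_pairs_bucketed([4, 3, 5, 2]))  # 4
```

Empty buckets contribute 0 to the sum of squares, so the zeros that `np.bincount` produces for unused bit lengths do not affect the result.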

You could also skip numpy altogether and use numba to JIT-compile your own code before it is run.

import numba

@numba.njit
def nba(array):
    n = len(array)
    ans = 0
    for i in range(0, n):
        p1 = array[i]
        for j in range(i, n):
            p2 = array[j]
            if p1 ^ p2 > p1 & p2:
                ans +=1
    return ans

Finally, here’s my implementation of Dave’s algorithm:

from collections import defaultdict
def new_alg(array):
    msb_num_count = defaultdict(int)
    for num in array:
        msb = len(bin(num)) - 2 # This was faster than right-shifting until zero
        msb_num_count[msb] += 1 # Increment the count of numbers that have this MSB
    
    # Each number can pair with every number outside its own group, so a group of
    # size r contributes r * (n - r); summing over groups counts each pair twice
    cnt = 0
    len_all_groups = len(array)
    for group_len in msb_num_count.values():
        cnt += group_len * (len_all_groups - group_len)

    return cnt // 2

And here it is as a numba-compatible function. I needed to define get_msb because numba.njit supports only a subset of Python builtins and wouldn't compile len(bin(num)).

@numba.njit
def get_msb(num):
    msb = 0
    while num:
        msb += 1
        num = num >> 1
    return msb

@numba.njit
def new_alg_numba(array):
    msb_num_count = {}
    for num in array:
        msb = get_msb(num)
        if msb not in msb_num_count:
            msb_num_count[msb] = 0
        msb_num_count[msb] += 1

    # Each number can pair with every number outside its own group, so a group of
    # size r contributes r * (n - r); summing over groups counts each pair twice
    cnt = 0
    len_all_groups = len(array)
        
    for grp_len in msb_num_count.values():
        cnt += grp_len * (len_all_groups - grp_len)

    return cnt // 2
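The hand-written get_msb loop computes the same quantity as `int.bit_length()`, and as `len(bin(num)) - 2` for positive numbers. A quick pure-Python check follows, with the `@numba.njit` decorator left off so it runs without numba installed:

```python
def get_msb(num):
    # Same loop as above, without @numba.njit: count right-shifts until num is 0
    msb = 0
    while num:
        msb += 1
        num = num >> 1
    return msb

for num in [1, 2, 3, 7, 8, 1000, 2**30]:
    assert get_msb(num) == num.bit_length() == len(bin(num)) - 2
print(get_msb(0), (0).bit_length(), len(bin(0)) - 2)  # 0 0 1
```

The three only disagree at 0 (`bin(0)` is `'0b0'`), which lies outside the question's input range of 1 to 2^30.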

Comparing runtimes, we see that the numba approach is significantly faster than the numpy approach, which is itself faster than looping in python.

The pure-Python implementation of Dave's linear-time algorithm (new_alg) is faster than the numpy approach from the start, and it overtakes the numba-compiled loop (nba) for inputs larger than ~1000 elements. The numba-compiled version of this approach (new_alg_numba) is faster still: it outpaces the numba-compiled loop at ~100 elements.

Kelly’s excellent implementation of Dave’s algorithm is on par with the numba version of my implementation for larger inputs.

[Plot: runtime of each function versus input size]

(Your implementation is labelled "loopy". Other legend labels in the plot are the same as function names in my answer above. Kelly’s implementation is labelled "kelly")
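For anyone who wants to reproduce a comparison like this, here is a minimal `timeit` sketch. It compares only a pure-Python pairwise loop and a bucket count (no numpy or numba), uses arbitrary, hypothetical input sizes and random data, and simply prints whatever timings your machine produces; no particular results are implied.

```python
import random
import timeit
from collections import Counter

def loopy(array):
    # O(n^2): check every pair i < j directly
    n = len(array)
    return sum(1 for i in range(n) for j in range(i + 1, n)
               if array[i] ^ array[j] > array[i] & array[j])

def bucketed(array):
    # O(n): bucket by highest set bit, then sum r * (n - r) and halve
    n = len(array)
    sizes = Counter(map(int.bit_length, array)).values()
    return sum(r * (n - r) for r in sizes) // 2

for size in (10, 100, 1000):  # hypothetical sizes; loopy grows quadratically
    data = [random.randint(1, 2**30) for _ in range(size)]
    assert loopy(data) == bucketed(data)
    t_loop = timeit.timeit(lambda: loopy(data), number=3) / 3
    t_bucket = timeit.timeit(lambda: bucketed(data), number=3) / 3
    print(f"n={size:>5}: loopy {t_loop:.2e} s, bucketed {t_bucket:.2e} s")
```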

Answered By: Pranav Hosangadi

Another implementation of Dave’s algorithm:

from collections import Counter

def solve(array):
    ctr = Counter(map(int.bit_length, array))
    n = len(array)
    return sum(r * (n-r) for r in ctr.values()) // 2

Answered By: Kelly Bundy