Find the minimum number of steps to halve the sum of the elements in a list, where each step halves one item in the list, in O(N)

Question:

I came across an interview question that went like this:

There are factories in an area which produce a polluting gas, and filters are to be installed at the factories to reduce the pollution. Each filter installed halves the pollution of that factory. Each factory can have multiple filters. There is a list of N integers representing the level of pollution in each of the N factories in the area. Find the minimum number of filters needed to halve the overall pollution.

E.g. – Let [3, 5, 6, 1, 18] be the list of pollution levels in 5 factories

  • Overall pollution = 3+5+6+1+18 = 33 (target is 33/2 = 16.5)

  • Install a filter in the factory at index=4 → pollution levels will be [3, 5, 6, 1, 9]

  • Install a filter in the factory at index=4 → pollution levels will be [3, 5, 6, 1, 4.5]

  • Install a filter in the factory at index=2 → pollution levels will be [3, 5, 3, 1, 4.5]

  • A minimum of 3 filters is needed to halve the overall pollution (3 + 5 + 3 + 1 + 4.5 = 16.5, which meets the target of 16.5).

N is an integer within the range [1..30,000]. Each element in the list is an integer within the range [0..70,000].

The solution I came up with for this was simple:
Find the max in the list and halve it, repeating until the sum is <= target.

def solution(A):
    """Return the number of filters needed to halve the total pollution.

    Greedy strategy: repeatedly halve the current largest level until the
    running total is no more than half of the starting total.

    Args:
        A: Pollution levels, one number per factory. The sequence is only
            read; the halving is done on a private copy.

    Returns:
        The number of halving steps performed (0 for an empty or all-zero
        input).
    """
    # Work on a copy so the caller's list is not rewritten as a side effect.
    levels = list(A)
    total = sum(levels)
    target = total / 2
    count = 0
    while total > target:
        count += 1
        # Both max() and index() scan the list, so every step is linear in N.
        max_p = max(levels)
        halved = max_p / 2
        total -= halved
        levels[levels.index(max_p)] = halved
    return count

This works well, except that the time complexity seems to be O(N^2). Can someone please suggest an approach to solve this with less time complexity (preferably O(N))?

Asked By: Yumna Albar

||

Answers:

Maybe you could utilize a max heap to retrieve the worst factory more efficiently than you are right now, i.e., using a heap would allow for an O(N log N) solution:

import heapq


def filters_required(factories: list[int]) -> int:
    """Count the filters needed to cut total pollution to half or less.

    Levels are stored negated so that ``heapq``'s min-heap hands back the
    most polluting factory first. Each step pops that factory, halves it
    and pushes it back, printing a DEBUG line with the current state.
    """
    remaining = sum(factories)
    threshold = remaining / 2
    negated_heap = [-level for level in factories]
    heapq.heapify(negated_heap)
    installed = 0
    while remaining > threshold:
        installed += 1
        # Heap entries are negated levels, so the halved value is negated too.
        halved = heapq.heappop(negated_heap) / 2
        heapq.heappush(negated_heap, halved)
        remaining += halved  # Adding the negated half lowers the total.
        print('DEBUG:', [-level for level in negated_heap], remaining)
    return installed


def main() -> None:
    """Print the filter count for the sample levels [3, 5, 6, 1, 18]."""
    # The `=` specifier echoes the call expression ahead of its result.
    print(f'{filters_required(factories=[3, 5, 6, 1, 18]) = }')


# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()

Output:

DEBUG: [9.0, 6, 3, 1, 5] 24.0
DEBUG: [6, 5, 3, 1, 4.5] 19.5
DEBUG: [5, 4.5, 3, 1, 3.0] 16.5
filters_required(factories=[3, 5, 6, 1, 18]) = 3
Answered By: Sash Sinha

My O(N log N) answer in Java:

/**
 * Counts the filters needed to bring total pollution down to at most half of
 * its starting value, always halving the factory that currently pollutes most.
 *
 * @param factories pollution level of each factory
 * @return number of halvings performed
 */
public static int pollution(double[] factories) {
    // Reverse ordering turns the queue into a max-heap: poll() yields the largest level.
    PriorityQueue<Double> worstFirst = new PriorityQueue<>(Collections.reverseOrder());
    double total = 0;
    for (double level : factories) {
        worstFirst.add(level);
        total += level;
    }

    final double target = total / 2;
    double remaining = total;
    int count = 0;
    for (; remaining > target; count++) {
        double reduced = worstFirst.poll() / 2;
        remaining -= reduced;
        worstFirst.add(reduced);
    }

    return count;
}
Answered By: Eldar Bril

I wrote a main method around the code above to make it easier to test.

import java.util.Arrays;
import java.util.Collections;
import java.util.PriorityQueue;

/**
 * Self-checking demo of the greedy "halve the worst factory" filter count.
 */
public final class PCFiltersCount
{
    /**
     * Counts the filters needed to bring total pollution down to at most half
     * of its starting value, always halving the currently largest level.
     *
     * @param aFactories pollution level of each factory
     * @return number of halvings performed
     */
    public static int pollution(final double[] aFactories)
    {
        int lFilters = 0;
        double lCurrSum = 0;

        // Reverse ordering makes poll() return the largest level first.
        final PriorityQueue<Double> lPriorityQueue = new PriorityQueue<>(Collections.reverseOrder());
        for (final double lLevel : aFactories)
        {
            lPriorityQueue.add(lLevel);
            lCurrSum += lLevel;
        }

        final double lHalf = lCurrSum / 2;

        while (lCurrSum > lHalf)
        {
            final double lReduced = lPriorityQueue.poll() / 2;
            lCurrSum -= lReduced;
            lPriorityQueue.add(lReduced);
            lFilters++;
        }

        return lFilters;
    }

    /**
     * Runs {@link #pollution(double[])} on the sample inputs, prints each
     * result and fails on the first one that differs from its expected count.
     *
     * @param args unused
     * @throws AssertionError if a computed count differs from the expected one
     */
    public static void main(final String[] args)
    {
        // Each entry is { pollution levels, { expected filter count } }.
        final double[][][] lCases = {
            {{15.0, 19, 8, 1}, {3}},
            {{10, 10}, {2}},
            {{3, 0, 51}, {2}},
            {{9.0, 6, 3, 1, 5}, {4}},
            {{6, 5, 3, 1, 4.5}, {5}},
            {{5, 4.5, 3, 1, 3.0}, {5}},
        };

        for (final double[][] lFactoryData : lCases)
        {
            final int lExpected = (int) lFactoryData[1][0];
            final int lResult = pollution(lFactoryData[0]);
            System.out.println("for Input: " + Arrays.toString(lFactoryData[0]) + " = " + lResult);
            // Checked explicitly: an `assert` statement is skipped unless the JVM runs with -ea.
            if (lResult != lExpected)
            {
                throw new AssertionError("expected " + lExpected + " but got " + lResult);
            }
        }
    }
}
Answered By: Raja Nagendra Kumar

If anyone is wondering about a solution in JavaScript, here’s my take on it.

function filtersRequired(factories) {
  // Counts how many times the currently worst factory must be halved before
  // total pollution is no greater than half of its starting value.
  let remaining = factories.reduce((sum, level) => sum + level, 0);
  const threshold = remaining / 2;

  // Levels are stored negated so the min-heap helpers surface the largest one.
  const negatedHeap = factories.map((level) => -level);
  makeHeap(negatedHeap);

  let installed = 0;
  while (remaining > threshold) {
    installed += 1;
    const halved = extractMin(negatedHeap) / 2;
    insert(negatedHeap, halved);
    // halved is a negated level, so adding it lowers the running total.
    remaining += halved;
    console.log(
      "DEBUG:",
      negatedHeap.map((level) => -level),
      remaining
    );
  }

  return installed;
}

function makeHeap(arr) {
  // Sift down every index from the middle of the array back to the root,
  // rearranging arr in place.
  let index = Math.floor(arr.length / 2);
  while (index >= 0) {
    heapify(arr, index);
    index -= 1;
  }
}

function heapify(arr, i) {
  // Sift the value at index i down, swapping it with its smaller child until
  // neither child is smaller than it. Mutates arr in place.
  let current = i;
  for (;;) {
    const leftChild = 2 * current + 1;
    const rightChild = 2 * current + 2;
    let smallest = current;

    if (leftChild < arr.length && arr[leftChild] < arr[smallest]) {
      smallest = leftChild;
    }
    if (rightChild < arr.length && arr[rightChild] < arr[smallest]) {
      smallest = rightChild;
    }
    if (smallest === current) {
      return;
    }

    const held = arr[current];
    arr[current] = arr[smallest];
    arr[smallest] = held;
    current = smallest;
  }
}

function extractMin(arr) {
  // Removes and returns the root of the heap (its smallest value), or
  // undefined when arr is empty.
  if (arr.length < 1) {
    return undefined;
  }

  const last = arr.pop();
  if (arr.length === 0) {
    // arr held a single element, which was both the root and the last item.
    return last;
  }

  // Move the last element to the root and sift it down to restore the heap.
  const smallest = arr[0];
  arr[0] = last;
  heapify(arr, 0);
  return smallest;
}

function insert(arr, val) {
  // Appends val, then bubbles it up while its parent is larger than it.
  arr.push(val);
  for (let child = arr.length - 1; child > 0; ) {
    const parent = Math.floor((child - 1) / 2);
    if (arr[parent] <= val) {
      return;
    }
    // val currently sits at child; trade places with the larger parent.
    arr[child] = arr[parent];
    arr[parent] = val;
    child = parent;
  }
}

// Demo entry point: logs the filter count for the sample levels [5, 19, 8, 1].
function main() {
  console.log(`Filters Required: ${filtersRequired([5, 19, 8, 1])}`);
}

// Run the demo as soon as the script is loaded.
main();
Answered By: lazzy_ms