Speed up random weighted choice without replacement in python

Question:

I want to sample ~10⁷ times from a population of ~10⁷ integers without replacement and with weights, picking 10 elements each time. After each sampling I change the weights. I have timed two approaches (Python 3 and NumPy) in the following script. Both approaches seem painfully slow to me; do you see a way of speeding them up?

import numpy as np
import random

@profile
def test_choices():
    """Profile NumPy vs. pure-Python weighted sampling without replacement.

    Meant to be run with ``kernprof -l -v script.py``. kernprof (line_profiler)
    injects the ``profile`` decorator, so running this file with plain
    ``python`` raises NameError.

    Each of the 10 iterations draws 10 distinct items with both approaches.
    It then increases one random weight, mimicking the "change the weights
    after every draw" workload.
    """
    n = 10**7
    population = list(range(n))
    # Separate ndarray population for NumPy. Passing the 10**7-element list
    # to np.random.choice would convert it to an array on every call.
    np_population = np.arange(n)
    weights = np.random.uniform(size=n)
    np_weights = np.array(weights)  # independent copy, updated in lockstep below

    def numpy_choice():
        # ndarray.sum() runs in C; the builtin sum() iterates over 10**7
        # NumPy scalars at Python speed.
        np_w = np_weights / np_weights.sum()
        return np.random.choice(np_population, size=10, replace=False, p=np_w)

    def python_choice():
        c = []
        # random.choices samples WITH replacement, so redraw until 10
        # distinct values have been collected.
        while len(c) < 10:
            c += random.choices(population=population, weights=weights, k=10 - len(c))
            c = list(set(c))
        return c

    for _ in range(10):
        numpy_choice()
        python_choice()

        add_weight = np.random.uniform()
        # randrange excludes its bound. The former randint(0, 10**7) could
        # return 10**7, which is out of range for both weight arrays.
        random_element = random.randrange(n)
        weights[random_element] += add_weight
        np_weights[random_element] += add_weight


test_choices()

With the timer result:

Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
    24        10   20720062.0 2072006.2     56.6          numpy_choice()
    25        10   15593925.0 1559392.5     42.6          python_choice()
Asked By: Nik

||

Answers:

You can try something like this. I have accelerated the function with Numba, but in my tests it is faster even without it.

import numpy as np
import numba as nb

@nb.njit
def numba_choice(population, weights, k):
    """Weighted sampling without replacement.

    Draws ``k`` elements of ``population`` at distinct positions, with
    probability proportional to ``weights``.

    Args:
        population: 1-D array to sample from.
        weights: non-negative weights, one per element of ``population``.
        k: number of samples.

    Returns:
        Array of ``k`` sampled elements with ``population``'s dtype.

    Raises:
        ValueError: if fewer than ``k`` weights are positive. Zero-weight
            positions are never drawn, so the loop could not finish otherwise.
    """
    # Count the positions that can actually be drawn.
    n_positive = 0
    for w in weights:
        if w > 0:
            n_positive += 1
    if k > n_positive:
        raise ValueError("k exceeds the number of items with positive weight")
    sample = np.empty(k, population.dtype)
    if k == 0:
        return sample
    # Cumulative weights; their last entry is the total weight.
    wc = np.cumsum(weights)
    m = wc[-1]
    # Positions drawn so far (-1 = slot not filled yet).
    sample_idx = np.full(k, -1, np.int64)
    i = 0
    while i < k:
        # Pick a uniform point in [0, m) and map it to a position.
        # side='right' sends r in [wc[j-1], wc[j]) to j.
        r = m * np.random.rand()
        idx = np.searchsorted(wc, r, side='right')
        # Reject the draw if this position was already taken. A bare
        # `continue` inside the inner loop would only advance `j`, so a
        # flag is needed to restart the outer `while`.
        # (Without Numba, a Python set or `np.isin` would do.)
        seen = False
        for j in range(i):
            if sample_idx[j] == idx:
                seen = True
                break
        if seen:
            continue
        sample[i] = population[idx]
        # Record the position, not the value, so equal values at different
        # positions remain distinct.
        sample_idx[i] = idx
        i += 1
    return sample

Here is a quick comparison:

def python_choice(population, weights, k):
    """Weighted sampling without replacement using only the standard library.

    Draws ``k`` items of ``population`` at distinct positions, with
    probability proportional to ``weights``. ``random.choices`` samples
    *with* replacement, so draws are repeated and duplicate positions
    discarded until ``k`` distinct positions have been collected.

    Args:
        population: indexable sequence to sample from.
        weights: non-negative weights, one per element of ``population``.
        k: number of samples.

    Returns:
        List of ``k`` elements of ``population``, in the order first drawn.

    Raises:
        ValueError: if fewer than ``k`` weights are positive. Zero-weight
            positions are never drawn, so the loop could not finish.
    """
    import itertools
    import random

    if k > sum(1 for w in weights if w > 0):
        raise ValueError("k exceeds the number of items with positive weight")
    positions = range(len(population))
    # Accumulate once. Passing plain `weights` would make random.choices
    # re-accumulate all of them on every call in the loop below.
    cum_weights = list(itertools.accumulate(weights))
    chosen = {}  # dict used as an insertion-ordered set of drawn positions
    while len(chosen) < k:
        for pos in random.choices(positions, cum_weights=cum_weights, k=k - len(chosen)):
            chosen[pos] = None
    return [population[pos] for pos in chosen]

def numpy_choice(population, weights, k):
    """Sample ``k`` distinct positions of ``population`` via ``np.random.choice``.

    The weights are normalised into a probability vector before being
    handed to NumPy, which requires ``p`` to sum to 1.
    """
    total = weights.sum()
    probabilities = weights / total
    return np.random.choice(population, k, False, probabilities)

# Test
# Run this in IPython/Jupyter: the `%timeit` lines below are IPython magics.
# population holds values 0..99 with many repeats (1_000_000 draws), so
# distinct sampled positions can still share the same value.
np.random.seed(0)
population = np.random.randint(100, size=1_000_000)
weights = np.random.rand(len(population))
k = 10
# python_choice draws from Python's `random` module, which np.random.seed
# does not seed, so its output below is not reproducible from the seed above.
print(python_choice(population, weights, k))
# [96, 99, 90, 46, 78, 16, 17, 22, 58, 30]
print(numpy_choice(population, weights, k))
# [ 9 61  1 18 41 89 55  4 53 40]
# The first call of the @nb.njit function also compiles it, so it is
# already compiled by the time %timeit runs. Numba-compiled code keeps its
# own random state, which np.random.seed called from Python does not reset.
print(numba_choice(population, weights, k))
# [66 82 91 62  9 56 71 14 32 26]

%timeit python_choice(population, weights, k)
# 198 ms ± 19.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
%timeit numpy_choice(population, weights, k)
# 13.4 ms ± 65.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%timeit numba_choice(population, weights, k)
# 2.08 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

EDIT: Here is how it could go without Numba:

import numpy as np

def loop_choice(population, weights, k):
    """Weighted sampling without replacement using plain NumPy (no Numba).

    Draws ``k`` elements of ``population`` at distinct positions, with
    probability proportional to ``weights``.

    Args:
        population: 1-D NumPy array to sample from.
        weights: non-negative weights, one per element of ``population``.
        k: number of samples.

    Returns:
        Array of ``k`` sampled elements with ``population``'s dtype.

    Raises:
        ValueError: if fewer than ``k`` weights are positive. Zero-weight
            positions are never drawn, so the loop could not finish.
    """
    weights = np.asarray(weights)
    if k > np.count_nonzero(weights > 0):
        raise ValueError("k exceeds the number of items with positive weight")
    sample = np.empty(k, population.dtype)
    if k == 0:
        return sample
    wc = np.cumsum(weights)
    m = wc[-1]  # total weight
    # Positions already drawn. Tracking positions (not values) keeps equal
    # values at different positions distinct; a set gives O(1) membership.
    taken = set()
    i = 0
    while i < k:
        r = m * np.random.rand()
        # side='right' maps r in [wc[j-1], wc[j]) to j, so zero-weight
        # positions are never returned.
        idx = int(np.searchsorted(wc, r, side='right'))
        if idx in taken:
            continue  # rejection: this position was already sampled, redraw
        taken.add(idx)
        sample[i] = population[idx]
        i += 1
    return sample

# Setup from before...
# (reuses population, weights and k from the comparison snippet above;
# `%timeit` is an IPython magic, so run this in IPython/Jupyter)
%timeit loop_choice(population, weights, k)
# 3.55 ms ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

EDIT: A small test to check that the sample frequencies follow the weights:

import numpy as np
import matplotlib.pyplot as plt

# Seeds NumPy's interpreter-level generator only; Numba-compiled code keeps
# its own random state, which this call does not reset.
np.random.seed(0)
n = 200
# population == arange(n), so sampled values double as indices into `a`.
population = np.arange(n)
# Non-negative weights shaped like a shifted sine wave (values in [0, 2]).
weights = np.sin(np.linspace(0, 2 * np.pi, n)) + 1
k = 15
r = 1600
# a[v] counts how often value v was sampled over r calls of k samples each.
a = np.zeros(n, np.int32)
for _ in range(r):
    c = numba_choice(population, weights, k)
    # np.add.at is unbuffered, so repeated entries in c are each counted.
    np.add.at(a, c, 1)
# Compare the normalised weights with the empirical per-item frequency.
# NOTE: for sampling without replacement, per-item inclusion frequencies are
# only approximately proportional to the weights (closer when k << n).
plt.figure()
plt.plot(weights / weights.sum(), label='Weights')
plt.plot(a / (k * r), label='Samples')
plt.legend()
plt.tight_layout()
plt.show()

Result:

Sample result

Answered By: jdehesa

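This is just a comment on jdehesa’s answer. The question was whether it is useful to handle the case where only one weight is increased separately — yes, it is! Instead of recomputing the whole cumulative sum, only the entries from the changed index onward need to be shifted by the added value.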

Example

@nb.njit(parallel=True)
def numba_choice_opt(population, weights, k, wc, b_full_wc_calc, ind, value):
    """Weighted sampling without replacement that reuses a cumulative-weight buffer.

    Args:
        population: 1-D array to sample from.
        weights: weights, one per element. Modified in place when
            ``b_full_wc_calc`` is False (``weights[ind] += value``).
        k: number of samples.
        wc: preallocated buffer with the same length as ``weights`` holding
            the cumulative weights; updated in place.
        b_full_wc_calc: if True, rebuild ``wc`` from ``weights``. If False,
            add ``value`` to ``weights[ind]`` and shift ``wc[ind:]`` by
            ``value``, which is cheaper than a full rebuild.
        ind: index of the weight to increase (unused when rebuilding).
        value: amount added to ``weights[ind]`` (unused when rebuilding).

    Returns:
        Array of ``k`` elements of ``population`` drawn at distinct positions.

    Raises:
        ValueError: if ``k`` exceeds the number of positions. This is checked
            before any in-place update.

    Note:
        If fewer than ``k`` positions have positive weight, the sampling loop
        never terminates (zero-weight positions are never drawn). No O(n)
        check is done here, to keep the incremental path cheap.
    """
    if k > wc.shape[0]:
        raise ValueError("k exceeds the population size")

    if b_full_wc_calc:
        # Full rebuild of the cumulative weights.
        acc = 0.0
        for t in range(weights.shape[0]):
            acc += weights[t]
            wc[t] = acc
    else:
        # Increase only one weight: shifting the tail of the cumulative sum
        # is cheaper than recalculating it.
        weights[ind] += value
        for q in nb.prange(ind, wc.shape[0]):
            wc[q] += value

    sample = np.empty(k, population.dtype)
    if k == 0:
        return sample
    # Total of weights
    m = wc[-1]
    # Positions drawn so far (-1 = slot not filled yet).
    sample_idx = np.full(k, -1, np.int64)
    i = 0
    while i < k:
        # Pick a random point in [0, m) and map it to a position.
        r = m * np.random.rand()
        idx = np.searchsorted(wc, r, side='right')
        # Reject the draw if this position was already taken. A bare
        # `continue` inside the inner loop would only advance `j`, so a
        # flag is needed to restart the outer `while`.
        seen = False
        for j in range(i):
            if sample_idx[j] == idx:
                seen = True
                break
        if seen:
            continue
        sample[i] = population[idx]
        # Record the position, not the value.
        sample_idx[i] = idx
        i += 1
    return sample

Example

np.random.seed(0)
population = np.random.randint(100, size=1_000_000)
weights = np.random.rand(len(population))
k = 10
# Preallocated cumulative-weight buffer, filled in place by numba_choice_opt.
wc = np.empty_like(weights)

# Initial calculation: rebuild wc from weights (full cumulative sum).
# `%timeit` is an IPython magic; run this in IPython/Jupyter.
%timeit numba_choice_opt(population, weights, k,wc,True,0,0)
#1.41 ms ± 9.21 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

# Increase weights[100] by 3 and sample. Every %timeit repetition applies the
# increment again, so weights[100] and wc keep growing while this is timed.
%timeit numba_choice_opt(population, weights, k,wc,False,100,3)
#213 µs ± 6.06 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

# For comparison.
# According to the author, it is the memory allocation of wc (np.cumsum
# creates a new array on every call) that makes this so much slower than
# the initial calculation above.
%timeit numba_choice(population, weights, k)
#4.23 ms ± 64.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
Answered By: max9111

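I think there might be a bug in both of these implementations: for me, `continue` does not actually implement sampling without replacement (it doesn’t seem to have any effect; I still get duplicate indices). The reason is that the `continue` sits inside the inner `for j` loop. It only skips to the next `j`, rather than rejecting the draw and restarting the outer `while` loop. (Separately, `sample_idx[i] = population[idx]` records the sampled value instead of the index, so the duplicate check compares indices against values.)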

Building on @jdehesa’s answer, here’s a version with (optional) sampling without replacement (note: it returns the indices rather than samples from an array, but this is an easy change to make).

@nb.njit
def nb_choice(max_n, k=1, weights=None, replace=False):
    '''
    Choose k indices from range(max_n), with optional weights and replacement.

    Args:
        max_n (int): number of candidate indices; samples come from 0 .. max_n-1
        k (int): number of samples
        weights (array): weight of each index (length max_n), if not uniform
        replace (bool): whether to sample with replacement

    Returns:
        int64 array of the k sampled indices.

    Raises:
        ValueError: if k > 0 and no weight is positive. Also raised when
            sampling without replacement asks for more indices than have
            positive weight: zero-weight indices are never drawn, so the
            loop could not finish.
    '''
    # Uniform weights by default
    if weights is None:
        weights = np.full(int(max_n), 1.0)

    # Validate up front so the rejection loop below always terminates.
    n_positive = 0
    for w in weights:
        if w > 0:
            n_positive += 1
    if k > 0 and n_positive == 0:
        raise ValueError("at least one weight must be positive")
    if not replace and k > n_positive:
        raise ValueError("k exceeds the number of indices with positive weight")

    inds = np.full(k, -1, dtype=np.int64)  # Sampled indices (-1 = not filled yet)
    if k == 0:
        return inds

    cumweights = np.cumsum(weights)
    maxweight = cumweights[-1]  # Total of weights

    # Sample
    i = 0
    while i < k:
        # Find the index: map a uniform point in [0, maxweight) to an index.
        r = maxweight * np.random.rand()
        ind = np.searchsorted(cumweights, r, side='right')

        # Optionally sample without replacement (reject repeated indices)
        found = False
        if not replace:
            for j in range(i):
                if inds[j] == ind:
                    found = True
                    break  # one match is enough to reject this draw
        if not found:
            inds[i] = ind
            i += 1

    return inds

Example:

n = 1_000_000
population = np.arange(n)
weights = np.random.rand(n)
k = 10
# nb_choice returns indices into range(n); index the population to get the
# sampled values (identical to the indices here, since population = arange(n)).
samples = population[nb_choice(n, k, weights)]

# `%timeit` is an IPython magic; run this in IPython/Jupyter.
%timeit nb_choice(n, k, weights, replace=False)
# 446 µs ± 1.16 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Answered By: Cliff Kerr
Categories: questions Tags: , , ,
Answers are sorted by their score. The answer accepted by the question owner as the best is marked with
at the top-right corner.