Numpy argmax – random tie breaking

Question:

In numpy.argmax, ties between multiple maximum elements are broken by returning the index of the first one.
Is there a way to randomize the tie breaking so that every maximum element has an equal chance of being selected?

Below is an example taken directly from the numpy.argmax documentation.

>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b) # Only the first occurrence is returned.
1

I am looking for a way to return index 1 and index 5 (both holding the maximum value 5) with equal probability.

Thank you!

Asked By: Jenna Kwon


Answers:

Use np.random.choice together with np.flatnonzero:

np.random.choice(np.flatnonzero(b == b.max()))

Let’s verify this on an array with three max candidates:

In [298]: b
Out[298]: array([0, 5, 2, 5, 4, 5])

In [299]: c=[np.random.choice(np.flatnonzero(b == b.max())) for i in range(100000)]

In [300]: np.bincount(c)
Out[300]: array([    0, 33180,     0, 33611,     0, 33209])
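The same one-liner can also be written against NumPy's newer Generator API instead of the legacy np.random module functions. The following is a minimal sketch, assuming NumPy 1.17 or later for np.random.default_rng. The helper name random_argmax is hypothetical and not part of NumPy. Passing an explicit rng makes the tie breaking reproducible when you need it to be.

```python
import numpy as np

def random_argmax(b, rng=None):
    """Flat index of a maximum of b, with ties broken uniformly at random.

    Hypothetical helper: same idea as np.random.choice(np.flatnonzero(b == b.max())),
    but drawing from a numpy.random.Generator.
    """
    if rng is None:
        rng = np.random.default_rng()
    candidates = np.flatnonzero(b == b.max())  # flat indices of all maxima
    return rng.choice(candidates)              # uniform pick among them

rng = np.random.default_rng(0)
b = np.array([0, 5, 2, 5, 4, 5])
counts = np.bincount([random_argmax(b, rng) for _ in range(30000)], minlength=b.size)
print(counts)  # zero at indices 0, 2 and 4; indices 1, 3 and 5 each close to 10000
```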
Answered By: Divakar

For a multi-dimensional array, especially when you need the argmax along an axis, choice won’t work directly, because it only samples from a 1-D array.

An alternative is

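import numpy as np

def randargmax(b, **kw):
    """ a random tie-breaking argmax"""
    # Random scores where b equals its maximum, 0 elsewhere; argmax then returns one
    # of the maximum positions at random. b.max() is the *global* maximum, so this
    # is only reliable for a flat argmax (no axis given).
    return np.argmax(np.random.random(b.shape) * (b == b.max()), **kw)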

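If generating random floats turns out to be slower than some other method, np.random.random can be replaced with that method, as long as it produces random positive values: after masking, non-maximum positions are 0, so the maxima need strictly larger scores.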
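The masking trick works because argmax only ever compares random scores at the maximum positions. Below is a minimal sketch of the same idea with two deliberate changes that are assumptions of this sketch, not part of the answer above. First, non-maximum positions get -inf via np.where, so a random draw of exactly 0.0 at a maximum can never tie with them. Second, the random numbers come from a numpy.random.Generator. The name masked_randargmax is hypothetical. It returns a flat index, which np.unravel_index converts back to coordinates.

```python
import numpy as np

def masked_randargmax(b, rng=None):
    """Flat random tie-breaking argmax via a random-score mask (hypothetical helper)."""
    if rng is None:
        rng = np.random.default_rng()
    # Random score in [0, 1) at every maximum, -inf everywhere else.
    scores = np.where(b == b.max(), rng.random(b.shape), -np.inf)
    # The largest random score sits at one of the maxima, each equally likely.
    return np.argmax(scores)

rng = np.random.default_rng(0)
m = np.array([[1, 7, 3],
              [7, 0, 7]])
row, col = np.unravel_index(masked_randargmax(m, rng), m.shape)
print(int(row), int(col))  # one of: 0 1, 1 0, 1 2
```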

Answered By: Manux

The easiest way is:

np.random.choice(np.where(b == b.max())[0])
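For a 1-D array, np.where(cond)[0] and np.flatnonzero(cond) return the same indices, so this answer and the accepted one are equivalent there. For an N-D array, np.where(cond) returns one index array per axis, so [0] holds only the first-axis (row) indices. The sketch below illustrates the difference on small example arrays chosen for illustration. It then picks a random maximum in 2-D by sampling one position from all coordinates together.

```python
import numpy as np

b = np.array([0, 5, 2, 5, 4, 5])
# 1-D: both give the positions of the maxima, [1 3 5].
print(np.where(b == b.max())[0], np.flatnonzero(b == b.max()))

m = np.array([[1, 7, 3],
              [7, 0, 7]])
mask = m == m.max()
rows, cols = np.where(mask)       # one index array per axis
print(rows)                       # [0 1 1]  first-axis indices only
print(np.flatnonzero(mask))       # [1 3 5]  flat indices into m.ravel()

# Pick one maximum in 2-D by choosing a position among all (row, col) pairs.
k = np.random.choice(len(rows))
print(rows[k], cols[k])
```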
Answered By: shyam padia

Since the accepted answer may not be obvious, here is how it works:

  • b == b.max() returns a boolean array that is True where an item equals the maximum and False everywhere else
  • np.flatnonzero() returns the indices of the True (non-zero) entries of that array, flattened if necessary. In other words, you get an array with the indices of every item matching the max value
  • Finally, np.random.choice() picks one of those indices uniformly at random
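The three steps above can be traced on the array from the question by printing each intermediate result. This is just an illustration of the accepted one-liner split into named steps; the variable names are chosen for readability.

```python
import numpy as np

b = np.array([0, 5, 2, 3, 4, 5])        # the array from the question

is_max = b == b.max()                   # step 1: boolean mask of the maxima
print(is_max)                           # [False  True False False False  True]

candidates = np.flatnonzero(is_max)     # step 2: indices where the mask is True
print(candidates)                       # [1 5]

choice = np.random.choice(candidates)   # step 3: uniform pick among them
print(choice)                           # 1 or 5, each with probability 1/2
```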
Answered By: bluephoton

Adding to @Manux’s answer:

Changing b.max() to np.amax(b, **kw, keepdims=True) lets you compute the argmax along an axis, because the maximum is then taken per slice and broadcast back against b.

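import numpy as np

def randargmax(b, **kw):
    """ a random tie-breaking argmax"""
    # Compare against the maximum along the requested axis (keepdims so it broadcasts).
    return np.argmax(np.random.random(b.shape) * (b == np.amax(b, keepdims=True, **kw)), **kw)

randargmax(b, axis=None)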
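Here is a self-contained check of the np.amax(..., keepdims=True) version along axis=1. The test matrix is an arbitrary example, not from the answer. Each returned column index should point at that row's maximum, and rows with several maxima should vary between runs.

```python
import numpy as np

def randargmax(b, **kw):
    """Random tie-breaking argmax; kw (e.g. axis=1) is passed to np.amax and np.argmax."""
    # Per-axis maxima, kept broadcastable against b.
    is_max = b == np.amax(b, keepdims=True, **kw)
    # Random scores at the maxima, 0 elsewhere; argmax picks one maximum at random.
    return np.argmax(np.random.random(b.shape) * is_max, **kw)

m = np.array([[1, 7, 3, 7],
              [9, 0, 9, 9],
              [2, 4, 4, 1]])
cols = randargmax(m, axis=1)
print(cols)                          # e.g. [3 0 2]; varies from run to run
print(m[np.arange(len(m)), cols])    # always [7 9 4]
```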
Answered By: asrvnon

Here is a timing comparison of the two main solutions, by @Divakar and @shyam-padia:

method (1) – using np.where

np.random.choice(np.where(b == b.max())[0])

method (2) – using np.flatnonzero

np.random.choice(np.flatnonzero(b == b.max()))

Code

Here is the code I wrote for the comparison:

import time
import numpy as np

def method1(b, bmax):
    # np.where returns a tuple of index arrays; [0] holds the indices for a 1-D b
    return np.random.choice(np.where(b == bmax)[0])

def method2(b, bmax):
    # np.flatnonzero returns the flat indices of the True entries
    return np.random.choice(np.flatnonzero(b == bmax))

def time_it(n):
    b = np.array([1.0, 2.0, 5.0, 5.0, 0.4, 0.1, 5.0, 0.3, 0.1])
    bmax = b.max()  # computed once, outside the timed loops

    start = time.perf_counter()
    for _ in range(n):
        method1(b, bmax)
    elapsed1 = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(n):
        method2(b, bmax)
    elapsed2 = time.perf_counter() - start

    print(f'method1 time: {elapsed1} - method2 time: {elapsed2}')
    return elapsed1, elapsed2
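A similar measurement can be made with the standard-library timeit module instead of a manual perf_counter loop. Taking the minimum of several repeats reduces noise from other processes. This is a sketch of a swapped-in approach, not the answer's code. The names time_methods, n and repeat are illustrative, and the sketch makes no claim about which method is faster on your machine.

```python
import timeit
import numpy as np

b = np.array([1.0, 2.0, 5.0, 5.0, 0.4, 0.1, 5.0, 0.3, 0.1])
bmax = b.max()

def method1():
    return np.random.choice(np.where(b == bmax)[0])

def method2():
    return np.random.choice(np.flatnonzero(b == bmax))

def time_methods(n=10_000, repeat=3):
    """Best-of-`repeat` total time in seconds for n calls of each method."""
    t1 = min(timeit.repeat(method1, number=n, repeat=repeat))
    t2 = min(timeit.repeat(method2, number=n, repeat=repeat))
    return t1, t2

t1, t2 = time_methods()
print(f"np.where: {t1:.4f} s   np.flatnonzero: {t2:.4f} s")
```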

Results

The following figure shows the total computation time of each method for 100, 1,000, 10,000, 100,000 and 1,000,000 iterations. The x-axis (logarithmic scale) is the number of iterations and the y-axis is the time in seconds. np.where performs better than np.flatnonzero as the number of iterations increases.

[Figure: time in seconds vs. number of iterations (log-scale x-axis) for np.where and np.flatnonzero]

To show how the two methods compare at lower iteration counts, the same results can be re-plotted with a logarithmic y-axis. np.where remains faster than np.flatnonzero at every iteration count.

[Figure: the same results re-plotted with a logarithmic y-axis]

Answered By: NKN