Is there a function that can apply NumPy's broadcasting rules to a list of shapes and return the final shape?

Question:

This is not a question about how broadcasting works (i.e., it’s not a duplicate of the existing questions on that topic).

I would just like to find a function that can apply NumPy’s broadcasting rules to a list of shapes and return the final shape, for example:

>>> broadcast_shapes([6], [4, 2, 3, 1], [2, 1, 1])
[4, 2, 3, 6]

Thanks!

Asked By: MiniQuark


Answers:

In [120]: shapes = [6], [4, 2, 3, 1], [2, 1, 1]
In [121]: arrs = np.broadcast_arrays(*[np.empty(shape, int) for shape in shapes])
In [122]: [a.shape for a in arrs]
Out[122]: [(4, 2, 3, 6), (4, 2, 3, 6), (4, 2, 3, 6)]

In [124]: np.lib.stride_tricks._broadcast_shape(*[np.empty(shape, int) for shape in shapes])
Out[124]: (4, 2, 3, 6)

In [131]: np.broadcast(*[np.empty(shape, int) for shape in shapes]).shape
Out[131]: (4, 2, 3, 6)

The second is quite a bit faster: 4.79 µs vs 42.4 µs for the first. The third is a tad faster still.

As I first commented, I started with broadcast_arrays and looked at its code. That led me to _broadcast_shape, and from there to np.broadcast.

Answered By: hpaulj

Here is a simple implementation, just in case someone needs it (it might help in understanding broadcasting). I would prefer using a NumPy function, though.

def broadcast_shapes(*shapes):
    max_rank = max(len(shape) for shape in shapes)
    # Left-pad each shape with 1s so they all have the same rank
    # (list() so tuples are accepted as well)
    shapes = [[1] * (max_rank - len(shape)) + list(shape) for shape in shapes]
    final_shape = [1] * max_rank
    for shape in shapes:
        for dim, size in enumerate(shape):
            if size != 1:  # a size-1 dimension never constrains the result
                final_size = final_shape[dim]
                if final_size == 1:
                    final_shape[dim] = size
                elif final_size != size:
                    raise ValueError("Cannot broadcast these shapes")
    return final_shape
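
For example, with the shapes from the question, plus a pair that cannot broadcast (a quick sanity check of the function above):

>>> broadcast_shapes([6], [4, 2, 3, 1], [2, 1, 1])
[4, 2, 3, 6]
>>> broadcast_shapes([2, 3], [4, 3])
Traceback (most recent call last):
    ...
ValueError: Cannot broadcast these shapes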

Edit

I timed this function against a few of the other answers, and it turned out to be the fastest (edit: Paul Panzer wrote an even faster function, see his answer; I added it to the list below):

%timeit bs_pp(*shapes) # Paul Panzer's answer
2.33 µs ± 10.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit broadcast_shapes1(*shapes)  # this answer
4.21 µs ± 11.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit broadcast_shapes2(*shapes) # my other answer with shapes.max(axis=0)
12.8 µs ± 67.3 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit broadcast_shapes3(*shapes) # user2357112's answer
18 µs ± 26.9 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

%timeit broadcast_shapes4(*shapes) # hpaulj's answer
18.1 µs ± 263 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
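
For reference, the benchmark presumably used the shapes from the question (an assumption; the setup isn't shown here):

shapes = [6], [4, 2, 3, 1], [2, 1, 1]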
Answered By: MiniQuark

As of NumPy 1.20, there’s a numpy.broadcast_shapes function that does exactly what you’re looking for. (It’s documented as taking tuples instead of lists, so you should probably pass it tuples just to be safe, but it accepts lists in practice.)

In [1]: import numpy

In [2]: numpy.broadcast_shapes((6,), (4, 2, 3, 1), (2, 1, 1))
Out[2]: (4, 2, 3, 6)
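
Incompatible shapes raise a ValueError; here 2 and 4 conflict in the leading axis (the exact error message depends on the NumPy version):

In [3]: numpy.broadcast_shapes((2, 3), (4, 3))
ValueError: ...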

For previous versions, you could broadcast a single 0-dimensional array to each target shape and then broadcast all the results against each other:

def broadcast_shapes(*shapes):
    base = numpy.array(0)
    # broadcast_to returns zero-strided views, so no real memory is allocated
    broadcast1 = [numpy.broadcast_to(base, shape) for shape in shapes]
    return numpy.broadcast(*broadcast1).shape

This avoids allocating large amounts of memory for large shapes. Needing to create arrays at all feels kind of silly, though.
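
Because nothing is materialized, even shapes whose broadcast result would be far too large to allocate are handled instantly (a quick check of the function above):

>>> broadcast_shapes((1000000, 1), (1, 1000000))
(1000000, 1000000)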

Answered By: user2357112

Assuming the shapes can actually be broadcast, this works:

def broadcast_shapes(*shapes):
    max_rank = max(len(shape) for shape in shapes)
    shapes = np.array([[1] * (max_rank - len(shape)) + shape
                       for shape in shapes])
    shapes[shapes == 1] = -1  # so a size-0 dim wins over a size-1 dim in max()
    final_shape = shapes.max(axis=0)
    final_shape[final_shape == -1] = 1
    return final_shape
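
The -1 trick is needed because of size-0 dimensions: NumPy broadcasts (0,) with (1,) to (0,), but a plain column-wise max of the padded shapes would report 1. A quick check against NumPy itself:

>>> np.broadcast(np.empty((0,)), np.empty((1,))).shape
(0,)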

If you assume there are no empty (size-0) dimensions, the -1 hack is not necessary:

def broadcast_shapes(*shapes):
    max_rank = max([len(shape) for shape in shapes])
    shapes = np.array([[1] * (max_rank - len(shape)) + shape
                      for shape in shapes])
    return shapes.max(axis=0)
Answered By: MiniQuark

Here is another direct implementation, which happens to beat the others on the example. Honorable mention goes to @hpaulj’s answer with @Warren Weckesser’s hack, which is almost as fast and much more concise:

def bs_pp(*shapes):
    ml = max(shapes, key=len)  # the longest shape serves as the template
    out = list(ml)
    for l in shapes:
        if l is ml:
            continue
        # start enumerate at -len(l): i then indexes out from the right,
        # since broadcasting aligns trailing dimensions
        for i, x in enumerate(l, -len(l)):
            if x != 1 and x != out[i]:
                if out[i] != 1:
                    raise ValueError
                out[i] = x
    return (*out,)

def bs_mq1(*shapes):
    max_rank = max([len(shape) for shape in shapes])
    shapes = [[1] * (max_rank - len(shape)) + shape for shape in shapes]
    final_shape = [1] * max_rank
    for shape in shapes:
        for dim, size in enumerate(shape):
            if size != 1:
                final_size = final_shape[dim]
                if final_size == 1:
                    final_shape[dim] = size
                elif final_size != size:
                    raise ValueError("Cannot broadcast these shapes")
    return (*final_shape,)

import numpy as np

def bs_mq2(*shapes):
    max_rank = max([len(shape) for shape in shapes])
    shapes = np.array([[1] * (max_rank - len(shape)) + shape
                      for shape in shapes])
    shapes[shapes==1] = -1
    final_shape = shapes.max(axis=0)
    final_shape[final_shape==-1] = 1
    return (*final_shape,)

def bs_hp_ww(*shapes):
    # appending a length-0 axis keeps np.empty from allocating any data;
    # [:-1] strips that axis off the broadcast result
    return np.broadcast(*[np.empty(shape + [0,], int) for shape in shapes]).shape[:-1]

L = [6], [4, 2, 3, 1], [2, 1, 1]

from timeit import timeit

print('pp:       ', timeit(lambda: bs_pp(*L), number=10_000)/10)
print('mq 1:     ', timeit(lambda: bs_mq1(*L), number=10_000)/10)
print('mq 2:     ', timeit(lambda: bs_mq2(*L), number=10_000)/10)
print('hpaulj/ww:', timeit(lambda: bs_hp_ww(*L), number=10_000)/10)

assert bs_pp(*L) == bs_mq1(*L) and bs_pp(*L) == bs_mq2(*L) and bs_pp(*L) == bs_hp_ww(*L)

Sample run:

pp:        0.0021552839782088993
mq 1:      0.00398325570859015
mq 2:      0.01497043427079916
hpaulj/ww: 0.003267909213900566
Answered By: Paul Panzer