Nested gridspec alignment

Question:

When using gridspecs, I find it difficult to align different nested gridspecs. I often use gridspecs for figures where most or all subplots have a fixed aspect ratio (e.g., to display images).

A minimal example would be the following plot, where two square images are displayed next to 4 smaller images in a nested subplot:

import matplotlib.pyplot as plt 
import numpy as np

n_cols = 3
fig = plt.figure(1, figsize=(6, 6 / n_cols * 1.5))
gs = fig.add_gridspec(1, n_cols)

test_img = np.ones((64, 64, 3)) * np.linspace(0.3, 1, 64)[:, None] # simple, square test image
for col in range(n_cols - 1):
    ax = fig.add_subplot(gs[col])
    ax.imshow(test_img)

gs_sub = gs[-1].subgridspec(2, 2, wspace=0.02, hspace=0.02)
for i in range(4):
    ax = fig.add_subplot(gs_sub[i])
    ax.imshow(test_img)

# use tight layout to remove excess white space
gs.tight_layout(fig, rect=[0, 0, 1, 1], pad=0.001)
gs.update(wspace=0.025, hspace=0.0)

This results in the following plot:

Plot 1

As you can see, the smaller images vertically use more space than the larger ones. I guess the nested gridspec tries to use all the available space, and is in no way restricted to match the two larger images on the left. On the other hand, it all aligns fine for plots with a flexible aspect ratio (e.g., line plots), as then the aspect ratio of the subplots stretches automatically:

Plot 2

(Don’t mind the overlapping axis ticks; it’s easy to add more space if needed.)

I can also oftentimes get things to work out okay by scaling the height of the plot or even playing with height/width ratios. In the above plot, the result can be improved by removing the arbitrary scale factor "1.5" that is applied to the plot height. However, this isn’t a good solution as it often requires a lot of manual experimentation and is rarely perfect (especially for more complex layouts).

Are there better ways of doing this? Is there a way to inform the nested gridspec of the desired alignment? Ideally, I would want to control the nested gridspec to match the height of the other subplots, rather than using up all the available space.

Asked By: delio


Answers:

The issue with displaying images in matplotlib plots is that the figure is usually over-constrained. In your example, you set a specific aspect ratio for the figure and then specify an arbitrary spacing between subplots. Unlike line plots, images have a fixed aspect ratio, so these parameters become inconsistent with each other, and one of the properties you try to set ends up being violated.

The solution is to make sure the parameters you set are consistent. Since the aspect ratio of your images is fixed, you can choose which is more important to you, the spacing between subplots or the global aspect ratio of the figure, and constrain accordingly.

In the following I will assume that the aspect ratio of the figure itself does not matter so much, so we will set the spacing between subplots, and derive what the figsize should be to respect the constraints.

According to the documentation, the wspace and hspace parameters are expressed as a fraction of the average axis width/height. Therefore, for a figure plotting a grid of n_rows x n_cols images of size w x h, the total width is:

W = n_cols * w + (n_cols - 1) * wspace * w

Likewise, the total height is:

H = n_rows * h + (n_rows - 1) * hspace * h
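The "fraction of the average width" convention that these formulas rely on can be checked numerically against matplotlib's own layout code. The sketch below is a quick check, assuming matplotlib 3.x: it builds a hypothetical 1 x 3 grid with unequal width ratios and asks the gridspec for its cell edges via `get_grid_positions`.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend; no window is opened
import matplotlib.pyplot as plt

fig = plt.figure(figsize=(8, 3))
# Unequal width ratios, so the average width differs from each individual width
gs = fig.add_gridspec(1, 3, wspace=0.1, width_ratios=[1, 2, 3])

# Cell edges in figure-relative coordinates (one entry per row / per column)
bottoms, tops, lefts, rights = gs.get_grid_positions(fig)
widths = rights - lefts
gaps = lefts[1:] - rights[:-1]

# Each gap divided by the mean cell width should give back wspace
print(gaps / widths.mean())
```

Because the gap is tied to the mean width rather than to its neighbours, the same absolute gap appears between every pair of columns, even when their widths differ.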

We can then use the ratio W / H as the aspect ratio for the main figure.
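As a worked example of the two formulas, here is a small sketch for the uniform case; the helper name `uniform_grid_aspect` is hypothetical. For a 2 x 2 grid of 200 x 150 images with wspace = 0.05 and hspace chosen so that the vertical gap equals the horizontal one, W = 2 * 200 + 0.05 * 200 = 410 and H = 2 * 150 + 10 = 310.

```python
def uniform_grid_aspect(n_rows, n_cols, w, h, wspace=0.0, hspace=0.0):
    """Width/height ratio of an n_rows x n_cols grid of w x h images.

    wspace and hspace follow matplotlib's convention: fractions of the
    average cell width and height, respectively.
    """
    W = n_cols * w + (n_cols - 1) * wspace * w
    H = n_rows * h + (n_rows - 1) * hspace * h
    return W / H


# 2 x 2 grid of 200 x 150 images, same absolute gap in both directions
r = 200 / 150
aspect = uniform_grid_aspect(2, 2, 200, 150, wspace=0.05, hspace=0.05 * r)
print(round(aspect, 4))
```

Setting the figure's width/height ratio to this value (ignoring figure margins) leaves no slack for matplotlib to distribute, so neither the spacing nor the image aspect has to give way.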

For simplicity, I assumed all images have the same size. One can derive a similar formula for arbitrary width/height ratios, but the idea remains the same.
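For the general case, the same reasoning gives the following, where w_j is the width of column j and h_i the height of row i (notation introduced here for illustration). Since wspace and hspace are fractions of the *average* cell size, the gap terms use the means:

```latex
W = \sum_{j=1}^{n_\text{cols}} w_j + (n_\text{cols} - 1)\,\mathit{wspace}\,\bar{w},
\qquad \bar{w} = \frac{1}{n_\text{cols}} \sum_{j=1}^{n_\text{cols}} w_j

H = \sum_{i=1}^{n_\text{rows}} h_i + (n_\text{rows} - 1)\,\mathit{hspace}\,\bar{h},
\qquad \bar{h} = \frac{1}{n_\text{rows}} \sum_{i=1}^{n_\text{rows}} h_i
```

When all w_j equal w and all h_i equal h, these reduce to the two formulas above. This is also what gridspec_aspect computes when it receives lists for w and h.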

To make this work with subgridspecs, apply the same computation to each subgridspec first. That gives you its aspect ratio, and therefore the space it occupies in the main grid, which you can in turn use to compute the aspect ratio of the main figure.
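The two-step procedure can be traced with plain arithmetic, using the same numbers as the example that follows (200 x 150 images, inner 2 x 2 grid, outer 1 x 3 row with wspace = 0.1). This is a sketch of the calculation only, written out without any matplotlib calls; the 10-inch figure width is an arbitrary choice.

```python
# Step 1: aspect of the inner 2 x 2 grid of 200 x 150 images
w, h = 200, 150
r = w / h
inner_wspace = 0.05
inner_hspace = inner_wspace * r          # same absolute gap in both directions
inner_W = 2 * w + inner_wspace * w
inner_H = 2 * h + inner_hspace * h
inner_aspect = inner_W / inner_H

# Step 2: the inner grid becomes one cell of the outer 1 x 3 row.
# With every outer cell 1 unit tall, each cell's width equals its aspect.
outer_widths = [r, r, inner_aspect]
outer_wspace = 0.1
mean_width = sum(outer_widths) / len(outer_widths)
outer_W = sum(outer_widths) + (len(outer_widths) - 1) * outer_wspace * mean_width
outer_aspect = outer_W / 1.0             # the single row has height 1

fig_width = 10                            # inches, chosen freely
fig_height = fig_width / outer_aspect
print(round(inner_aspect, 4), round(outer_aspect, 4), round(fig_height, 3))
```

The inner aspect feeds into the outer width ratios, and only the outermost result determines the figure size.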

Below is an illustration of this idea with a slightly modified version of your example figure. I implemented a gridspec_aspect function that computes the actual aspect ratio of a (sub)figure given arbitrary width/height ratios and wspace/hspace values.


import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np

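def gridspec_aspect(n_rows, n_cols, w, h, wspace=0, hspace=0):
    """Return the width/height aspect of a gridspec.

    w and h are either a single number (all columns/rows have that size)
    or a list/tuple of per-column widths / per-row heights.
    wspace and hspace are fractions of the average cell width/height.
    """
    # Total width/height of the cells themselves (scalars may be int or float)
    if isinstance(w, (list, tuple)):
        Ws = sum(w)
    else:
        Ws = n_cols * w
    if isinstance(h, (list, tuple)):
        Hs = sum(h)
    else:
        Hs = n_rows * h

    # Absolute gaps: wspace/hspace times the average cell width/height
    w_spacing = wspace * Ws / n_cols
    h_spacing = hspace * Hs / n_rows

    return (Ws + (n_cols - 1) * w_spacing) / (Hs + (n_rows - 1) * h_spacing)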

n_cols = 3
n_rows = 1
test_img = np.ones((150, 200, 3)) * np.linspace(0.3, 0.9, 200)[:, None] # simple test image

# Image aspect ratio
h,w = test_img.shape[:2]
r = w / h

# Spacing in the inner gridspec
inner_wspace = 0.05
inner_hspace = inner_wspace * r # same vertical spacing as horizontal spacing
inner_aspect = gridspec_aspect(2, 2, w, h, inner_wspace, inner_hspace)

# Spacing in the main gridspec
outer_wspace = 0.1
outer_aspect = gridspec_aspect(n_rows, n_cols, [r, r, inner_aspect], 1, outer_wspace)

fig = plt.figure(1, figsize=(10, 10 / outer_aspect))
gs = fig.add_gridspec(1, n_cols, wspace=outer_wspace, width_ratios=[r, r, inner_aspect])
inner_gs = gridspec.GridSpecFromSubplotSpec(2, 2, subplot_spec=gs[2], wspace=inner_wspace, hspace=inner_hspace)

for col in range(n_cols - 1):
    ax = fig.add_subplot(gs[col])
    ax.imshow(test_img)
    ax.axis('off')

for i in range(2):
    for j in range(2):
        ax = fig.add_subplot(inner_gs[i,j])
        ax.imshow(test_img)
        ax.axis('off')
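One caveat worth checking: fig.add_gridspec uses the figure's default subplot margins (rcParams figure.subplot.left/right/bottom/top, 0.125/0.9/0.11/0.88 by default), so the grid area is not exactly proportional to figsize and the fit above is off by a fraction of a percent. A possible variant, sketched under the assumption of matplotlib 3.x, passes left=0, right=1, bottom=0, top=1 so the grid fills the whole figure. It then uses get_grid_positions to confirm that each outer cell has the aspect of its content and each inner cell has the image aspect r. The aspects are recomputed inline so the block runs on its own; `cell_aspects_of` is a hypothetical helper.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

h, w = 150, 200            # image size, as in the test image above
r = w / h

# Inner 2 x 2 grid: same absolute gap horizontally and vertically
inner_wspace = 0.05
inner_hspace = inner_wspace * r
inner_aspect = (2 * w + inner_wspace * w) / (2 * h + inner_hspace * h)

# Outer 1 x 3 row: cells are 1 unit tall, so each width equals its aspect
widths = [r, r, inner_aspect]
outer_wspace = 0.1
n = len(widths)
outer_aspect = sum(widths) * (1 + (n - 1) * outer_wspace / n)

fig_w = 10.0
fig_h = fig_w / outer_aspect
fig = plt.figure(figsize=(fig_w, fig_h))

# left/right/bottom/top = 0/1/0/1: the grid fills the whole figure (no margins)
gs = fig.add_gridspec(1, n, left=0, right=1, bottom=0, top=1,
                      wspace=outer_wspace, width_ratios=widths)
inner_gs = gs[2].subgridspec(2, 2, wspace=inner_wspace, hspace=inner_hspace)

img = np.ones((h, w, 3)) * np.linspace(0.3, 0.9, w)[None, :, None]
for k in range(2):
    fig.add_subplot(gs[k]).imshow(img)
for k in range(4):
    fig.add_subplot(inner_gs[k]).imshow(img)
for ax in fig.axes:
    ax.axis("off")


def cell_aspects_of(grid):
    # Figure-relative cell extents converted to inches, then width / height
    bottoms, tops, lefts, rights = grid.get_grid_positions(fig)
    return (rights - lefts) * fig_w / ((tops - bottoms) * fig_h)


cell_aspects = cell_aspects_of(gs)              # expected to match widths
inner_cell_aspects = cell_aspects_of(inner_gs)  # expected to equal r
print(cell_aspects, inner_cell_aspects)
```

If you want to keep some outer padding, the same idea applies: scale the figure height by the ratio of the horizontal to the vertical margin fractions so the grid area, not the whole figure, has aspect outer_aspect.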

This generates the following figure:

[Figure: output of the code above]

Answered By: bathal