Row and column headers in matplotlib's subplots

Question:

What’s the best practice for adding a row header and a column header to a grid of subplots generated in a loop in matplotlib? I can think of a couple of ways, but neither is particularly neat:

  1. For columns, with a counter in your loop you can use set_title() for the first row only. For rows this doesn’t work; you would have to draw text outside of the plots.
  2. Add an extra row of subplots on top and an extra column of subplots on the left, and draw text in the middle of those subplots.
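For reference, option 2 might look roughly like the sketch below. It is an illustration under assumptions, not code from the question: the 0.15 size ratio of the header row and column, and the names grid and axes, are arbitrary choices.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt

rows = ["Row A", "Row B", "Row C", "Row D"]
cols = ["Column 1", "Column 2", "Column 3"]

# One extra narrow row on top and one extra narrow column on the left
# hold the header text; the data axes fill the remaining cells.
fig, grid = plt.subplots(
    nrows=len(rows) + 1,
    ncols=len(cols) + 1,
    figsize=(12, 8),
    gridspec_kw=dict(
        height_ratios=[0.15] + [1] * len(rows),
        width_ratios=[0.15] + [1] * len(cols),
    ),
)

# Hide frames and ticks of the header cells.
for ax in list(grid[0, :]) + list(grid[1:, 0]):
    ax.axis("off")

# Draw each header in the middle of its cell (axes-fraction coordinates).
for ax, col in zip(grid[0, 1:], cols):
    ax.text(0.5, 0.5, col, transform=ax.transAxes,
            ha="center", va="center", size="large")
for ax, row in zip(grid[1:, 0], rows):
    ax.text(0.5, 0.5, row, transform=ax.transAxes,
            ha="center", va="center", size="large", rotation=90)

axes = grid[1:, 1:]  # the actual data axes, shape (len(rows), len(cols))
for ax in axes.flat:
    ax.plot([0, 1], [0, 1])
```

The downside the question hints at is visible here: the header cells are real axes, so spacing between them and the data axes is controlled by the gridspec rather than by the text itself.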

Can you suggest a better alternative?


Asked By: gozzilli


Answers:

There are several ways to do this. The easy way is to exploit the y-labels and titles of the plot and then use fig.tight_layout() to make room for the labels. Alternatively, you can place additional text in the right location with annotate and then make room for it semi-manually.


If you don’t have y-labels on your axes, it’s easy to exploit the title and y-label of the first row and column of axes.

import matplotlib.pyplot as plt

cols = ['Column {}'.format(col) for col in range(1, 4)]
rows = ['Row {}'.format(row) for row in ['A', 'B', 'C', 'D']]

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(12, 8))

for ax, col in zip(axes[0], cols):
    ax.set_title(col)

for ax, row in zip(axes[:,0], rows):
    ax.set_ylabel(row, rotation=0, size='large')

fig.tight_layout()
plt.show()
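As far as I understand matplotlib's defaults, a y-label is anchored at its bottom centre because it is normally drawn rotated. With rotation=0 the text is therefore centred horizontally on a point just left of the tick labels, so longer row names can run over them. A small variant of the code above (my adjustment, not part of the original answer) right-aligns and vertically centres the row labels instead:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt

cols = ['Column {}'.format(col) for col in range(1, 4)]
rows = ['Row {}'.format(row) for row in ['A', 'B', 'C', 'D']]

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(12, 8))

for ax, col in zip(axes[0], cols):
    ax.set_title(col)

for ax, row in zip(axes[:, 0], rows):
    # Anchor the label's right edge (instead of its centre) next to the
    # tick labels, and centre it vertically on the row.
    ax.set_ylabel(row, rotation=0, size='large', ha='right', va='center')

fig.tight_layout()
```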



If you do have y-labels, or if you prefer a bit more flexibility, you can use annotate to place the labels. This is more complicated, but allows you to have individual plot titles, y-labels, etc. in addition to the row and column labels.

import matplotlib.pyplot as plt
from matplotlib.transforms import offset_copy


cols = ['Column {}'.format(col) for col in range(1, 4)]
rows = ['Row {}'.format(row) for row in ['A', 'B', 'C', 'D']]

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(12, 8))
plt.setp(axes.flat, xlabel='X-label', ylabel='Y-label')

pad = 5 # in points

for ax, col in zip(axes[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                xycoords='axes fraction', textcoords='offset points',
                size='large', ha='center', va='baseline')

for ax, row in zip(axes[:,0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size='large', ha='right', va='center')

fig.tight_layout()
# tight_layout doesn't take these labels into account. We'll need
# to make some room. These numbers are manually tweaked.
# You could automatically calculate them, but it's a pain.
fig.subplots_adjust(left=0.15, top=0.95)

plt.show()
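The comment in the code says the margins could be computed automatically. One way to sketch that (my assumption, not part of the original answer) is to draw the figure once, measure the headers with get_window_extent, and move the left/top subplot edges by however far the headers stick out. Because the headers are offset from the axes or their y-labels, they move along with the subplot edges. fit_headers and pad_px are hypothetical names, not matplotlib API.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; measuring text needs a renderer
import matplotlib.pyplot as plt


def fit_headers(fig, texts, pad_px=4, max_iter=3):
    """Widen the left/top margins until every artist in `texts` lies
    inside the figure. Hypothetical helper, not part of matplotlib."""
    for _ in range(max_iter):
        fig.canvas.draw()  # annotation positions are only known after a draw
        renderer = fig.canvas.get_renderer()
        boxes = [t.get_window_extent(renderer) for t in texts]
        width, height = fig.bbox.width, fig.bbox.height
        # How far the headers overflow the left/top edges, as figure fractions.
        left_over = max(0.0, pad_px - min(b.x0 for b in boxes)) / width
        top_over = max(0.0, max(b.y1 for b in boxes) - (height - pad_px)) / height
        if left_over == 0 and top_over == 0:
            break
        sp = fig.subplotpars
        fig.subplots_adjust(left=sp.left + left_over, top=sp.top - top_over)


cols = ['Column {}'.format(col) for col in range(1, 4)]
rows = ['Row {}'.format(row) for row in ['A', 'B', 'C', 'D']]

fig, axes = plt.subplots(nrows=4, ncols=3, figsize=(12, 8))
plt.setp(axes.flat, xlabel='X-label', ylabel='Y-label')
pad = 5  # in points

headers = []  # keep the Annotation objects so they can be measured
for ax, col in zip(axes[0], cols):
    headers.append(ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                               xycoords='axes fraction', textcoords='offset points',
                               size='large', ha='center', va='baseline'))
for ax, row in zip(axes[:, 0], rows):
    headers.append(ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                               xycoords=ax.yaxis.label, textcoords='offset points',
                               size='large', ha='right', va='center'))

fig.tight_layout()
fit_headers(fig, headers)  # replaces the hand-tuned subplots_adjust call
```

The loop runs a few times because moving the subplot edges can, in principle, change tick labels and therefore the y-label position; in a simple grid like this one pass is usually enough.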


Answered By: Joe Kington

The above answer works. Just note that in the second version of the answer, the row loop passes col (left over from the column loop) where it should pass row. You have:

for ax, row in zip(axes[:,0], rows):
    ax.annotate(col, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad,0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size='large', ha='right', va='center')

instead of:

for ax, row in zip(axes[:,0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad,0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size='large', ha='right', va='center')

Answered By: Alan Shteyman

Based on Joe Kington’s answer, I put together a function that can be reused across a code base:

It accepts as arguments:

  • fig : The figure which contains the axes to work on
  • row_headers, col_headers : sequences of strings to use as headers
  • row_pad, col_pad : int values to adjust the padding, in points
  • rotate_row_headers : whether to rotate the row headers by 90°
  • **text_kwargs: forwarded to ax.annotate(...)

Function here, examples below:

def add_headers(
    fig,
    *,
    row_headers=None,
    col_headers=None,
    row_pad=1,
    col_pad=5,
    rotate_row_headers=True,
    **text_kwargs
):
    # Based on https://stackoverflow.com/a/25814386

    axes = fig.get_axes()

    for ax in axes:
        sbs = ax.get_subplotspec()
        if sbs is None:  # e.g. axes added with fig.add_axes(), not on a grid
            continue

        # Putting headers on cols
        if (col_headers is not None) and sbs.is_first_row():
            ax.annotate(
                col_headers[sbs.colspan.start],
                xy=(0.5, 1),
                xytext=(0, col_pad),
                xycoords="axes fraction",
                textcoords="offset points",
                ha="center",
                va="baseline",
                **text_kwargs,
            )

        # Putting headers on rows
        if (row_headers is not None) and sbs.is_first_col():
            ax.annotate(
                row_headers[sbs.rowspan.start],
                xy=(0, 0.5),
                xytext=(-ax.yaxis.labelpad - row_pad, 0),
                xycoords=ax.yaxis.label,
                textcoords="offset points",
                ha="right",
                va="center",
                rotation=rotate_row_headers * 90,
                **text_kwargs,
            )

Here is an example of using it on a standard grid (no axes span multiple rows / cols):

import matplotlib.pyplot as plt

mosaic = [
    ["A0", "A1", "A2"],
    ["B0", "B1", "B2"],
]
row_headers = ["Row A", "Row B"]
col_headers = ["Col 0", "Col 1", "Col 2"]

subplots_kwargs = dict(sharex=True, sharey=True, figsize=(10, 6))
fig, axes = plt.subplot_mosaic(mosaic, **subplots_kwargs)

font_kwargs = dict(fontfamily="monospace", fontweight="bold", fontsize="large")
add_headers(fig, col_headers=col_headers, row_headers=row_headers, **font_kwargs)

plt.show()

[Figure: result on a regular grid]

If some axes span multiple rows / cols, it is a bit less straightforward to assign the row / col headers correctly.
I didn’t manage to sort it out inside the function, but being careful with the row_headers and col_headers arguments is enough to make it work:

mosaic = [
    ["A0", "A1", "A1", "A2"],
    ["A0", "A1", "A1", "A2"],
    ["B0", "B1", "B1", "B2"],
]

row_headers = ["A", "A", "B"]  # or
row_headers = ["A", None, "B"]  # or
row_headers = {0: "A", 2: "B"}

col_headers = ["0", "1", "1", "2"]  # or
col_headers = ["0", "1", None, "2"]  # or
col_headers = {0: "0", 1: "1", 3: "2"}

fig, axes = plt.subplot_mosaic(mosaic, **subplots_kwargs)
add_headers(fig, col_headers=col_headers, row_headers=row_headers, **font_kwargs)
plt.show()

[Figure: result on a non-regular grid]
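If you would rather resolve spans inside add_headers, one possible approach (my own sketch; header_for_span is a hypothetical helper, not part of the answer or of matplotlib) is to look at every grid index an axes covers, using the range objects returned by SubplotSpec.rowspan / SubplotSpec.colspan, and take the first header that is not None:

```python
def header_for_span(headers, span):
    """Return the first non-None header for the grid indices in `span`.

    `headers` may be a sequence (None meaning "no header here") or a dict
    mapping grid index -> header; `span` is a range, such as
    SubplotSpec.colspan or SubplotSpec.rowspan. Hypothetical helper.
    """
    for i in span:
        if isinstance(headers, dict):
            header = headers.get(i)
        else:
            header = headers[i] if i < len(headers) else None
        if header is not None:
            return header
    return None


# Spans from the non-regular mosaic above:
# "A1" covers columns range(1, 3), "A0" covers rows range(0, 2).
print(header_for_span(["0", None, "1", "2"], range(1, 3)))  # prints 1
print(header_for_span({0: "A", 2: "B"}, range(0, 2)))  # prints A
```

Inside add_headers you would then call header_for_span(col_headers, sbs.colspan) instead of indexing col_headers[sbs.colspan.start] (and likewise for rows), skipping the annotation when it returns None. With that change a header can sit at any index within a span, e.g. col_headers = ["0", None, "1", "2"].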

Answered By: paime