Struggling with Matplotlib subplots in a for loop

Question:

I have the following code:

# Make plot of channels with gaps
fig, ax = plt.subplots(nrows=len(gap_list), ncols=1, figsize=(12,len(gap_list)), sharex=True, squeeze=False)

for ch in gap_list:
    i = gap_list.index(ch)
    resample_s = 4*ch_gap[ch]['rec_rate']
    ylabel = ch + ' (' + ch_gap[ch]['board'] +') - '+ ch_gap[ch]['unit']
    data = df[ch].resample(f'{resample_s}s').mean()
    is_nan = data.isnull()
    ax[i].fill_between(data.index, 0, (is_nan*data.max()), color='r', step='mid', linewidth='0')
    ax[i].plot(data.index, data, color='b', linestyle='-', marker=',', label=ylabel)
    ax[i].legend(loc='upper left')


plt.tight_layout()
plt.show()

Here gap_list is a list of column names from a pandas DataFrame (df), and its length can be anywhere from 1 to 10. The code works fine when nrows > 1. However, when nrows == 1 it raises this exception:

'AxesSubplot' object is not subscriptable

So then I found the squeeze kwarg and set it to False. All good, I thought, but now the code raises this exception:

'numpy.ndarray' object has no attribute 'fill_between'

So then I took a different tack: I created the figure outside the loop and the subplots inside it:

fig = plt.figure(figsize=(12,len(gap_list)))

Then I created each axis in the for loop as below:

ax = plt.subplot(len(gap_list), 1, i+1)

This works for both nrows == 1 and nrows > 1. However, I then can’t find a nice way of making all the subplots share the x axis. In the original method I could just pass sharex=True to plt.subplots().
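For reference, the figure-first approach can also share the x axis: plt.subplot forwards keyword arguments such as sharex to Figure.add_subplot, so each new subplot can share with the first one. A minimal sketch, assuming a hypothetical gap_list and synthetic sine data in place of df and ch_gap:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

gap_list = ["ch1", "ch2", "ch3"]  # hypothetical channel names; one entry works too
x = np.arange(100)

fig = plt.figure(figsize=(12, len(gap_list)))
axes = []
for i, ch in enumerate(gap_list):
    # share the x axis with the first subplot; the first one gets sharex=None
    ax = plt.subplot(len(gap_list), 1, i + 1, sharex=axes[0] if axes else None)
    ax.plot(x, np.sin(x / 10 + i), label=ch)
    ax.legend(loc="upper left")
    axes.append(ax)

# plt.subplot does not hide redundant tick labels the way plt.subplots(sharex=True) does
for a in axes:
    a.label_outer()

plt.tight_layout()
```

Because the limits are shared, zooming or calling set_xlim on any one subplot moves all of them together.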

So it feels like the original method was more on the right lines, but with a missing ingredient to get the nrows=1 case handled better.

Asked By: almonde


Answers:

I think it would be most straightforward to keep your original code (with the default squeeze=True, i.e. without squeeze=False) and just check whether ax is a NumPy array.

When nrows > 1, ax will be a NumPy array of matplotlib Axes, so index into it. When nrows == 1, ax will be a single Axes object, so use it directly.

import numpy as np

...

for ch in gap_list:
    
    ...
    
    # if `ax` is a numpy array then index it, else just use `ax`
    ax_i = ax[i] if isinstance(ax, np.ndarray) else ax

    # now just use the `ax_i` handle
    ax_i.fill_between(data.index, 0, (is_nan*data.max()), color='r', step='mid', linewidth='0')
    ax_i.plot(data.index, data, color='b', linestyle='-', marker=',', label=ylabel)
    ax_i.legend(loc='upper left')
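A self-contained version of this approach, assuming a small synthetic DataFrame with an integer index in place of the real df and ch_gap (the resampling step is omitted), runs unchanged for one or several channels. The helper name plot_channels is hypothetical:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

def plot_channels(df, channels):
    # default squeeze=True: a single Axes for one row, a 1-D array for several
    fig, ax = plt.subplots(nrows=len(channels), ncols=1,
                           figsize=(12, 2 * len(channels)), sharex=True)
    for i, ch in enumerate(channels):
        # if `ax` is a numpy array then index it, else just use `ax`
        ax_i = ax[i] if isinstance(ax, np.ndarray) else ax
        data = df[ch]
        is_nan = data.isnull()
        # shade the gaps in red, up to the channel maximum
        ax_i.fill_between(data.index, 0, is_nan * data.max(),
                          color="r", step="mid", linewidth=0)
        ax_i.plot(data.index, data, color="b", linestyle="-", label=ch)
        ax_i.legend(loc="upper left")
    fig.tight_layout()
    return fig

# synthetic data with a gap in rows 10-14
df = pd.DataFrame({"a": np.arange(50.0), "b": np.arange(50.0) ** 0.5})
df.iloc[10:15] = np.nan

fig_one = plot_channels(df, ["a"])
fig_two = plot_channels(df, ["a", "b"])
```

Using enumerate instead of gap_list.index(ch) also avoids a repeated linear search and gives the right position even if a channel name appears twice.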

Answered By: tdy

Variants on this question have been asked several times. tdy’s answer is the best I have seen, since it avoids low-level manipulation of ax. The problem arises because plt.subplots (with its default squeeze=True) returns a single Axes object when the "fig, ax = …" statement creates one subplot, but an array of Axes objects when it creates several.

As almonde’s example illustrates, the programmer often wants to write code that will work the same way when there are one or more axes objects since the required number may only be known at run time.
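Another way to get code that behaves the same for one or more axes is the squeeze=False option the question already tried. With squeeze=False, plt.subplots always returns a 2-D array of shape (nrows, ncols), so ax[i] is a whole row (a 1-D array), which is why ax[i].fill_between raised "'numpy.ndarray' object has no attribute 'fill_between'". Indexing with ax[i, 0], or flattening once with ravel(), yields individual Axes. A minimal sketch with placeholder data; make_axes is a hypothetical helper name:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def make_axes(n):
    # squeeze=False: `ax` is always a 2-D array of shape (n, 1), even when n == 1
    fig, ax = plt.subplots(nrows=n, ncols=1, sharex=True, squeeze=False)
    # flatten to a 1-D array so that ax_flat[i] is a single Axes
    return fig, ax.ravel()

x = np.arange(20)
for n in (1, 3):
    fig, ax_flat = make_axes(n)
    for i, a in enumerate(ax_flat):
        a.plot(x, x * (i + 1), label=f"ch{i}")
        a.legend(loc="upper left")
```

This keeps sharex=True and needs no type check inside the loop.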

An alternative to tdy’s solution is to force creation of an array of Axes objects by asking for one more than needed, then deleting the last element of the ax array. In almonde’s example (stripping away irrelevant keywords) we replace

fig, ax = plt.subplots(nrows=len(gap_list), ncols=1)

with the following two statements:

fig, ax = plt.subplots(nrows=len(gap_list) + 1, ncols=1)
fig.delaxes(ax[-1])

ax is now always an array. If len(gap_list) == 1, then i = 0 on the single iteration of the for loop, and the references to ax[i] work as expected.
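The extra-row trick can be checked in isolation. This sketch uses placeholder data and a hypothetical single-entry gap_list (the case that used to fail). Two side effects seem worth knowing: the deleted row's space stays reserved in the grid, leaving an empty strip at the bottom of the figure, and with sharex=True recent matplotlib versions hide the x tick labels on every row except the last, so after deleting that row the labels may need to be switched back on, here with tick_params(labelbottom=True):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

gap_list = ["ch1"]  # hypothetical: a single channel

# ask for one row more than needed so `ax` is always an array...
fig, ax = plt.subplots(nrows=len(gap_list) + 1, ncols=1, sharex=True)
# ...then remove the surplus Axes from the figure
fig.delaxes(ax[-1])

x = np.arange(20)
for i, ch in enumerate(gap_list):
    ax[i].plot(x, x, label=ch)
    ax[i].legend(loc="upper left")

# sharex=True hid the tick labels above the deleted row;
# re-enable them on the row that is now at the bottom
ax[len(gap_list) - 1].tick_params(labelbottom=True)
```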

Answered By: CaveMan