Setting xlabels for each subplot in a seaborn PairGrid object plotted using matplotlib

Question:

I am trying to set individual xlabels for each subplot in a seaborn PairGrid object, but the plot won't update and only shows the xlabels for the bottom-most row of plots.

import matplotlib.pyplot as plt
import seaborn as sns

g = sns.PairGrid(dat, x_vars=inputs, y_vars=outputs, hue='variable')

def scatter_plt(x, y, *a, **kw):
    # Draw fully transparent points when a column is plotted against itself
    if x.equals(y):
        kw["color"] = (0, 0, 0, 0)
    plt.scatter(x, y, *a, **kw)
    plt.xticks(rotation=90)
    plt.subplots_adjust(wspace=0.4, hspace=0.4)

g.map(scatter_plt)

I tried the following, but it did not work; the plot looked the same as before.

xlabels,ylabels = [],[]

for ax in g.axes[-1,:]:
    xlabel = ax.xaxis.get_label_text()
    xlabels.append(xlabel)
for ax in g.axes[:,0]:
    ylabel = ax.yaxis.get_label_text()
    ylabels.append(ylabel)

for i in range(len(xlabels)):
    for j in range(len(ylabels)):
        g.axes[j,i].xaxis.set_label_text(xlabels[i])
        g.axes[j,i].yaxis.set_label_text(ylabels[j])

Asked By: Sherwin R


Answers:

Seaborn makes the axis labels of the inner subplots invisible, so setting their text alone changes nothing on screen; you need to explicitly set them visible again.
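A minimal check of this diagnosis, assuming a small hand-made DataFrame (the columns `a`, `b`, `c` are made up for illustration, so nothing needs to be downloaded): the inner subplot accepts the new label text, but the label stays flagged as invisible until its visibility is switched back on.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the check runs without a display
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical toy data; any numeric DataFrame would do
df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0],
                   "b": [4.0, 3.0, 2.0, 1.0],
                   "c": [2.0, 4.0, 1.0, 3.0]})
g = sns.PairGrid(df, x_vars=["a", "b"], y_vars=["b", "c"])
g.map(plt.scatter)

# Top-right subplot: not in the bottom row, so seaborn hid its x label
inner = g.axes[0, 1]

# Set only the text, as in the question's second attempt
inner.set_xlabel("b")
was_visible = inner.xaxis.label.get_visible()

# The missing step: make the label visible again
inner.xaxis.label.set_visible(True)
```

Passing `visible=True` to `set_xlabel()` does the same thing in a single call, since extra keyword arguments are applied to the label's `Text` object.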

Here is how the code could look. Some other details have also changed:

  • Seaborn has some example datasets that can be used for quick testing. Here, the iris dataset is used for easy reproducibility.
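  • plt.subplots_adjust(...) only needs to be called once, as it changes the spacing of the whole figure; calling it inside the mapped function repeats it for every subplot. Instead of plt.subplots_adjust(), plt.tight_layout() is often easier to use, as it automatically tries to optimize the spacing between subplots.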
  • Building the xlabels and ylabels lists with list comprehensions makes the code shorter, less error-prone and easier to change.
  • Along the same lines, idiomatic Python minimizes the use of explicit indices, which is why constructions such as for i, xlabel in enumerate(xlabels) are so common.
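The point about plt.subplots_adjust() can be checked with plain matplotlib. This sketch uses a bare 3-by-4 grid of subplots as a stand-in for the PairGrid (an assumption for illustration) and shows that the spacing is stored once on the figure, so calling it once per subplot only rewrites the same figure-wide setting.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for a quick check
import matplotlib.pyplot as plt

fig, axes = plt.subplots(3, 4)

# Mimic the question's mapped function: one call per subplot
for ax in axes.flat:
    plt.sca(ax)
    plt.subplots_adjust(wspace=0.4, hspace=0.4)

# The values live on the figure, not on the individual subplots
spacing = (fig.subplotpars.wspace, fig.subplotpars.hspace)

# Alternative: let matplotlib compute the spacing automatically
fig.tight_layout()
```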
import matplotlib.pyplot as plt
import seaborn as sns

iris = sns.load_dataset('iris')
g = sns.PairGrid(iris, x_vars=iris.columns[0:4], y_vars=iris.columns[0:3], hue='species')

def scatter_plt(x, y, *a, **kw):
    # Leave a cell empty where a column would be plotted against itself
    if not x.equals(y):
        plt.scatter(x, y, *a, **kw)
    plt.tick_params(axis='x', rotation=90)

g.map(scatter_plt)

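# Read the labels seaborn placed on the bottom row and the left column
xlabels = [ax.xaxis.get_label_text() for ax in g.axes[-1, :]]
ylabels = [ax.yaxis.get_label_text() for ax in g.axes[:, 0]]

# Copy them to every subplot; visible=True undoes seaborn's hiding of inner labels
for i, xlabel in enumerate(xlabels):
    for j, ylabel in enumerate(ylabels):
        g.axes[j, i].set_xlabel(xlabel, visible=True)
        g.axes[j, i].set_ylabel(ylabel, visible=True)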

plt.tight_layout()
plt.show()

[Image: sns.PairGrid with internal labels]
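A variant in the same spirit, sketched under the assumption that the grid's variable names are available as g.x_vars and g.y_vars (seaborn's PairGrid stores them under these names): the labels can be taken from there instead of being read back from the outer subplots, and zip() replaces the explicit indices. The small DataFrame below is a hypothetical stand-in for the iris data so the snippet runs offline.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the iris data
df = pd.DataFrame({
    "a": [1.0, 2.0, 3.0, 4.0],
    "b": [4.0, 3.0, 2.0, 1.0],
    "c": [2.0, 4.0, 1.0, 3.0],
    "d": [3.0, 1.0, 4.0, 2.0],
    "group": ["x", "x", "y", "y"],
})
g = sns.PairGrid(df, x_vars=["a", "b", "c", "d"], y_vars=["a", "b", "c"], hue="group")

def scatter_plt(x, y, *a, **kw):
    # Leave a cell empty where a column would be plotted against itself
    if not x.equals(y):
        plt.scatter(x, y, *a, **kw)
    plt.tick_params(axis="x", rotation=90)

g.map(scatter_plt)

# g.axes has one row per y variable and one column per x variable,
# so its transpose yields the subplots column by column
for column_axes, xlabel in zip(g.axes.T, g.x_vars):
    for ax, ylabel in zip(column_axes, g.y_vars):
        ax.set_xlabel(xlabel, visible=True)
        ax.set_ylabel(ylabel, visible=True)

plt.tight_layout()
```

Taking the names from the grid itself avoids depending on which subplots seaborn chose to label.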

Answered By: JohanC