How to draw a seaborn pairplot in several rows?

Question:

I have a dataset with 76 features and one dependent variable (y). I use seaborn in a Jupyter notebook to draw a pairplot of each feature against y. Because there are so many features, each individual plot is very small.

I am looking for a way to draw the pairplot in several rows. I also don't want to copy and paste the pairplot code into several notebook cells; I want to generate the figure automatically.

The code I am using (I cannot share the dataset, so I use a sample dataset):

from sklearn.datasets import load_boston
import math
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

X, y = load_boston(return_X_y=True)
X = pd.DataFrame(X)
y = pd.DataFrame(y)
data = pd.concat([X, y], axis=1)

y_name = 'y'
features_names = [f'feature_{i}' for i in range(1, X.shape[1]+1)]  
column_names = features_names + [y_name]
data.columns = column_names

plot_size=7
num_plots_x=5   # No. of plots in every row
num_plots_y = math.ceil(len(features_names)/num_plots_x)   # No. of plots in y direction

fig = plt.figure(figsize=(plot_size*num_plots_y, plot_size*num_plots_x), facecolor='white')
axes = [fig.add_subplot(num_plots_y,1,i+1) for i in range(num_plots_y)]   

for i, ax in enumerate(axes):   
    start_index = i * num_plots_x
    end_index = (i+1) * num_plots_x
    if end_index > len(features_names): end_index = len(features_names)
    sns.pairplot(x_vars=features_names[start_index:end_index], y_vars=y_name, data = data)

plt.savefig('figure.png')

The above code has two problems. First, it shows an empty box at the top of the output, followed by the pairplots. Second, it saves only the last row to the PNG file, not the whole figure.

If you have any idea to solve this, please let me know. Thank you.

Asked By: Mohammad


Answers:

When I run it directly (python script.py), it opens every row in a separate window, so each row is treated as a separate figure and only the last one is saved to the file.

The other problem is that pairplot doesn't use your fig and axes: it creates its own figure on every call, so it can't draw into subplots to put everything on one image. When I remove fig and axes, the first window with the empty box no longer appears.
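
A minimal sketch of this behaviour, assuming a small synthetic DataFrame (the random data and the four feature columns are made up for illustration, not the dataset from the question). Each pairplot call returns a grid with its own figure, so the figure created by hand stays empty and the last grid ends up as the current figure:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data (an assumption, not the dataset from the question).
rng = np.random.default_rng(0)
data = pd.DataFrame(
    rng.normal(size=(50, 5)),
    columns=["feature_1", "feature_2", "feature_3", "feature_4", "y"],
)

plt.close("all")
empty_fig = plt.figure()  # the figure created by hand, as in the question

# Two "rows", each drawn by its own pairplot call.
grids = [
    sns.pairplot(data, x_vars=["feature_1", "feature_2"], y_vars=["y"]),
    sns.pairplot(data, x_vars=["feature_3", "feature_4"], y_vars=["y"]),
]

# Number of open figures: the hand-made one plus one per pairplot call.
print(len(plt.get_fignums()))
# Number of axes on the hand-made figure: pairplot never drew into it.
print(len(empty_fig.axes))
# plt.savefig writes the current figure, which is the last grid's figure.
print(plt.gcf() is grids[-1].figure)
```

This is why plt.savefig('figure.png') in the question writes only the last row.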


I found that FacetGrid has a col_wrap parameter that wraps the columns into several rows. Someone has also suggested adding col_wrap to pairplot in the seaborn issue "Add parameter col_wrap to pairplot" (#2121), and there is also an example of how to use FacetGrid with scatterplot instead of pairplot, so that col_wrap can be used.


Here is code that uses FacetGrid with col_wrap:

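# Note: load_boston was removed in scikit-learn 1.2. On newer versions,
# substitute another loader that returns (X, y), e.g. load_diabetes.
from sklearn.datasets import load_boston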
import math
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd

X, y = load_boston(return_X_y=True)
X = pd.DataFrame(X)
y = pd.DataFrame(y)
data = pd.concat([X, y], axis=1)

y_name = 'y'
features_names = [f'feature_{i}' for i in range(1, X.shape[1]+1)]  
column_names = features_names + [y_name]
data.columns = column_names

num_plots_x = 4   # No. of plots in every row

# sns.pairplot creates a new figure on every call, so looping over it
# cannot put all rows on one figure. Instead, build an empty FacetGrid with
# one facet per feature (col_wrap starts a new row every num_plots_x facets)
# and draw each scatterplot into its own axes.
g = sns.FacetGrid(pd.DataFrame(features_names), col=0, col_wrap=num_plots_x, sharex=False)
for ax, x_var in zip(g.axes, features_names):
    sns.scatterplot(data=data, x=x_var, y=y_name, ax=ax)
g.tight_layout()

plt.savefig('figure.png')
plt.show()

Result ('figure.png'): a single image with all the scatterplots, four per row.
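
A variant of the same idea, as a sketch rather than what the code above does: reshape the data to long form with melt and let relplot build the wrapped FacetGrid. The synthetic data, the seven feature columns and the output filename are assumptions for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in data (an assumption, not the dataset from the question).
rng = np.random.default_rng(0)
features_names = [f"feature_{i}" for i in range(1, 8)]
data = pd.DataFrame(rng.normal(size=(50, 8)), columns=features_names + ["y"])

# Long form: one row per (observation, feature) pair, with y repeated per feature.
long_data = data.melt(
    id_vars="y", value_vars=features_names,
    var_name="feature", value_name="value",
)

# One facet per feature, three facets per row; each facet keeps its own x range.
g = sns.relplot(
    data=long_data, x="value", y="y",
    col="feature", col_wrap=3, height=3,
    facet_kws={"sharex": False},
)
g.savefig("figure_relplot.png")  # saves the whole grid, not just one row
```

Because relplot returns a single FacetGrid, everything lives on one figure and g.savefig writes all rows at once.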

Answered By: furas

from sklearn.datasets import load_boston
import math
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

X, y = load_boston(return_X_y=True)
X = pd.DataFrame(X)
y = pd.DataFrame(y)
data = pd.concat([X, y], axis=1)

y_name = 'y'
features_names = [f'feature_{i}' for i in range(1, X.shape[1]+1)]  
column_names = features_names + [y_name]
data.columns = column_names


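n_cols = 3
n_rows = int(np.ceil(len(features_names) / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 20))
# Plot each feature against y (iterating over features_names, not
# column_names, so that y is not plotted against itself).
for ax, att in zip(axes.ravel(), features_names):
    sns.scatterplot(ax=ax, x=att, y=y_name, data=data, alpha=0.8)
# Hide the unused axes in the last row.
for ax in axes.ravel()[len(features_names):]:
    ax.set_visible(False)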

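
The subplots approach can be wrapped in a small helper that also scales the figure with the grid, so individual plots do not shrink as the number of features grows. This is only a sketch: plot_feature_grid, its parameters, the synthetic data and the output filename are hypothetical names chosen for illustration.

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def plot_feature_grid(data, features, target, n_cols=3, plot_size=4):
    """Scatter each feature against target on one figure, n_cols plots per row."""
    n_rows = math.ceil(len(features) / n_cols)
    # figsize is (width, height): columns set the width, rows set the height.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(plot_size * n_cols, plot_size * n_rows),
        squeeze=False,  # always a 2D array, even for a single row or column
    )
    for ax, feature in zip(axes.ravel(), features):
        sns.scatterplot(data=data, x=feature, y=target, ax=ax, alpha=0.8)
    # Hide the axes left over in the last row.
    for ax in axes.ravel()[len(features):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig


# Synthetic stand-in data (an assumption, not the dataset from the question).
rng = np.random.default_rng(0)
features_names = [f"feature_{i}" for i in range(1, 8)]
data = pd.DataFrame(rng.normal(size=(50, 8)), columns=features_names + ["y"])

fig = plot_feature_grid(data, features_names, "y", n_cols=3)
fig.savefig("figure_grid.png")  # one file containing every row
```

Calling fig.savefig on the returned figure, rather than plt.savefig, makes it explicit which figure is written to disk.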

Answered By: Rajesh kaki