How to normalize histograms in jointplot margins

Question:

How can I normalize the histograms in sns.jointplot?

Right now I’m getting this plot

enter image description here

However, instead of showing raw counts (0, 200, 400) on the marginal axes, I would like each bar to show its fraction of the total number of data points in the dataset.

This is my code:

import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib import rcParams
sns.set(style='white')

# sample x & y: doesn't match the top sample image
x = np.random.exponential(70, 100)
y = np.random.exponential(0.005, 100)

# Scatter plot with (still unnormalized) count histograms in the margins.
# The data are passed by keyword: newer seaborn releases no longer accept
# x and y as positional arguments.
g = sns.jointplot(
    x=x, y=y, xlim=[30, 800], ylim=[0.0007, 0.023],
    marker='*', s=10, color='k', ec='k',
    marginal_kws=dict(bins=10, fill=True, log_scale=True),
    height=1, ratio=6, marginal_ticks=True,
)
# Overlay a filled KDE on the joint axes.
g.plot_joint(sns.kdeplot, color="r", zorder=0, levels=10, fill=True, thresh=0.1)

g.ax_joint.invert_xaxis()
# Resize the figure after creation.
g.fig.set_size_inches(14, 10)

# Dashed red reference lines on the joint axes.
g.ax_joint.axvline(x=70, color='red', ls='--', lw=2)
g.ax_joint.axvline(x=650, color='red', ls='--', lw=2)
g.ax_joint.axhline(y=0.005, color='red', ls='--', lw=2)
Asked By: JohnGoodWill

||

Answers:

You can use g.ax_marg_x and g.ax_marg_y to access the axes objects corresponding to the marginal distributions. Then you can iterate over Axes.get_children() to filter out the "bar" objects which are represented as matplotlib.patches.Rectangle. Then the height/width of these rectangle objects can be adjusted and the axes limits can be adjusted as well:

import matplotlib.patches


def normalize_marginal(joint_grid, dim):
    """Rescale the histogram bars of one marginal axes so that they sum to 1.

    Parameters
    ----------
    joint_grid
        Object exposing the marginal axes as ``ax_marg_x`` / ``ax_marg_y``
        (e.g. the grid returned by ``sns.jointplot``).
    dim : str
        ``'x'`` rescales the bar heights of ``ax_marg_x`` and sets its y-limits;
        ``'y'`` rescales the bar widths of ``ax_marg_y`` and sets its x-limits.

    Raises
    ------
    AttributeError
        If ``joint_grid`` has no ``ax_marg_<dim>`` attribute.
    """
    # Use the grid that was passed in, not a module-level global.
    ax = getattr(joint_grid, f'ax_marg_{dim}')
    size_attr = 'height' if dim == 'x' else 'width'

    # `get_children()` also returns the axes background (`ax.patch`), which is
    # itself a Rectangle; it must not be counted or resized as a bar.
    bars = [
        obj for obj in ax.get_children()
        if isinstance(obj, matplotlib.patches.Rectangle) and obj is not ax.patch
    ]
    sizes = [getattr(bar, f'get_{size_attr}')() for bar in bars]
    total = sum(sizes)  # computed once instead of once per bar

    # Skip the rescale when there is nothing to normalize (avoids dividing by 0).
    if total:
        for bar, size in zip(bars, sizes):
            getattr(bar, f'set_{size_attr}')(size / total)

    # The bars extend along the axis perpendicular to `dim`.
    set_lim = ax.set_ylim if dim == 'x' else ax.set_xlim
    set_lim([0, 1])  # customize (e.g. `set_lim([0, 1.2 * max(sizes) / total])`)


# Normalize both marginal histograms of the joint grid `g`.
normalize_marginal(joint_grid=g, dim='x')
normalize_marginal(joint_grid=g, dim='y')
Answered By: a_guest
  • As stated in the seaborn.jointplot documentation, marginal_kws accepts the parameters of the plot type used in the margins (e.g. sns.histplot in this case).
  • From sns.histplot: stat='probability' (or 'proportion') normalizes the counts such that the bar heights sum to 1.
    • Pass this parameter to marginal_kws
  • If you want the total area of the histogram to equal 1, then use 'density'.
  • If there are multiple groups of data from using hue, then also consider adding common_bins=False and common_norm=False to marginal_kws.
  • Tested in python 3.10 and seaborn 0.11.2
import seaborn as sns
import numpy as np

# reproducible sample data
np.random.seed(365)
x = np.random.exponential(70, 100)
y = np.random.exponential(0.005, 100)

# options for the marginal histograms; stat='probability' is the normalization
hist_opts = {'bins': 10, 'fill': True, 'log_scale': True, 'color': 'r', 'stat': 'probability'}

# joint scatter plot with the marginals configured above
g = sns.jointplot(
    x=x, y=y,
    marker='*', s=10, color='k', height=7, ec='k',
    marginal_ticks=True,
    marginal_kws=hist_opts,
)
# filled KDE on the joint axes
g.plot_joint(sns.kdeplot, color="r", zorder=0, levels=10, fill=True, thresh=0.1)
g.ax_joint.invert_xaxis()

enter image description here


  • Instead of using xlim = [30, 800] and ylim = [0.0007, 0.023], which truncates the plot, you can mask the data first.
    • Both masks are applied to x and y to ensure the resulting datasets are the same length.
  • This is optional, depending on how the data should be presented, and the bins will be normalized only on the selected data.
# reproducible sample data
np.random.seed(365)
x = np.random.exponential(70, 100)
y = np.random.exponential(0.005, 100)

# one boolean mask per variable, using the former axis limits as bounds
x_mask = (x > 30) & (x < 800)
y_mask = (y > 0.0007) & (y < 0.023)

# keep only the points that pass both masks, so x_sel and y_sel stay aligned
both_mask = x_mask & y_mask
x_sel = x[both_mask]
y_sel = y[both_mask]

# options for the marginal histograms
hist_opts = {
    'bins': 10, 'fill': True, 'log_scale': True, 'color': 'r',
    'stat': 'probability', 'common_bins': False, 'common_norm': False,
}

# joint scatter plot of the selected data, with a filled KDE on the joint axes
g = sns.jointplot(
    x=x_sel, y=y_sel,
    marker='*', s=10, color='k', height=7, ec='k',
    marginal_ticks=True,
    marginal_kws=hist_opts,
)
g.plot_joint(sns.kdeplot, color="r", zorder=0, levels=10, fill=True, thresh=0.1)
g.ax_joint.invert_xaxis()

enter image description here

Answered By: Trenton McKinney