Is there a gridExtra/cowplot-type package in Python to use with Plotnine to align subplots (i.e. marginal distributions)?

Question:

This is related to the following posts on aligning subplots in ggplot2 in R:

Scatterplot with marginal histograms in ggplot2

How to plot Scatters with marginal density plots use Python?

Perfectly align several plots

There are many more links on StackOverflow on this topic.

Anyhow, I haven’t been able to find any documentation on how to do this with Plotnine (ggplot2 implementation) in Python.

Can this be done at all in Python? If not, has there been no demand for a solution to this problem?

Thank you for your guidance.

UPDATE:

I found this link, "MEP for a matplotlib geometry manager #1109", which seems to indicate that the functionality to do this is now present in matplotlib. The question is: has plotnine been extended to do this?

Asked By: codingknob

||

Answers:

I don't know enough about plotnine to say whether it can do this, but it is very easy in seaborn with jointplot: https://seaborn.pydata.org/generated/seaborn.jointplot.html

>>> import numpy as np, pandas as pd; np.random.seed(0)
>>> import seaborn as sns; sns.set(style="white", color_codes=True)
>>> tips = sns.load_dataset("tips")
>>> g = sns.jointplot(x="total_bill", y="tip", data=tips)
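
If seaborn is not an option, roughly the same layout can be sketched in plain matplotlib with a GridSpec and shared axes. This is only a minimal sketch, not what jointplot does internally: the helper name scatter_with_marginals, the 4:1 size ratios, and the random sample data are all assumptions made for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def scatter_with_marginals(x, y, bins=20):
    """Hypothetical helper: scatter plot with aligned marginal histograms."""
    fig = plt.figure(figsize=(5, 5))
    # 2x2 grid: wide/tall main cell at bottom-left, thin strips for the marginals.
    gs = fig.add_gridspec(2, 2, width_ratios=(4, 1), height_ratios=(1, 4),
                          wspace=0.05, hspace=0.05)
    ax_main = fig.add_subplot(gs[1, 0])
    # Sharing an axis with the main panel is what keeps the marginals aligned.
    ax_top = fig.add_subplot(gs[0, 0], sharex=ax_main)
    ax_right = fig.add_subplot(gs[1, 1], sharey=ax_main)

    ax_main.scatter(x, y, s=10)
    ax_top.hist(x, bins=bins)
    ax_right.hist(y, bins=bins, orientation="horizontal")

    # Hide the tick labels that would duplicate those of the main panel.
    ax_top.tick_params(axis="x", labelbottom=False)
    ax_right.tick_params(axis="y", labelleft=False)
    return fig, ax_main, ax_top, ax_right


# Made-up sample data, purely for illustration.
rng = np.random.default_rng(0)
x = rng.normal(size=200)
y = 0.5 * x + rng.normal(scale=0.5, size=200)
fig, ax_main, ax_top, ax_right = scatter_with_marginals(x, y)
```

Calling fig.savefig("marginals.png") afterwards writes the figure to disk. Because the marginal axes share their x or y axis with the main panel, zooming or rescaling one keeps the others lined up.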
Answered By: Gus

By combining patchworklib and plotnine v0.9.0, you can draw a scatter plot with marginal density plots as follows.

import patchworklib as pw
from plotnine import *
from plotnine.data import *

g1 = pw.load_ggplot(ggplot(mpg, aes(x='cty', color='drv', fill='drv')) +
                    geom_density(aes(y=after_stat('count')), alpha=0.1) +
                    scale_color_discrete(guide=False) +
                    theme(axis_ticks_major_x=element_blank(),
                          axis_text_x =element_blank(),
                          axis_title_x=element_blank(),
                          axis_text_y =element_text(size=12),
                          axis_title_y=element_text(size=14),
                          legend_position="none"),
                    figsize=(4,1))

g2 = pw.load_ggplot(ggplot(mpg, aes(x='hwy', color='drv', fill='drv')) +
                    geom_density(aes(y=after_stat('count')), alpha=0.1) +
                    coord_flip() +
                    theme(axis_ticks_major_y=element_blank(),
                          axis_text_y =element_blank(),
                          axis_title_y=element_blank(),
                          axis_text_x =element_text(size=12),
                          axis_title_x=element_text(size=14)
                         ),
                    figsize=(1,4))

g3 = pw.load_ggplot(ggplot(mpg) +
                    geom_point(aes(x="cty", y="hwy", color="drv")) +
                    scale_color_discrete(guide=False) +
                    theme(axis_text =element_text(size=12),
                          axis_title=element_text(size=14)
                         ),
                    figsize=(4,4))

pw.param["margin"] = 0.2
# Specifying [g3] after (g3|g2) positions g1 exactly above g3.
(g1/(g3|g2)[g3]).savefig()


Answered By: Hideto