Dynamic pandas subplots with matplotlib

Question:

I need help creating subplots in matplotlib dynamically from a pandas dataframe.
The data I am using is from data.world.

I have already created the viz but the plots have been created manually.


The reason I need it to be dynamic is that I am going to apply a filter dynamically (in Power BI) and I need the graph to adjust to the filter.

This is what I have so far:
I imported the data and got it into the shape I need:

import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# read file from makeover monday year 2018 week 48
df = pd.read_csv(r'C:\Users\Ruth Pozuelo\Documents\python_examples\data\2018w48.csv', usecols=["city", "category", "item", "cost"], index_col=False, decimal=",")
df.head()

This is the table: [image]

I then apply the filter that will come from Power BI dynamically:

df = df[df.category=='Party night'] 

and then I count the number of plots based on the number of items I get after I apply the filter:

itemCount = df['item'].nunique()         #number of plots

If I then plot the subplots:

fig, ax = plt.subplots( nrows=1, ncols=itemCount ,figsize=(30,10), sharey=True) 

I get the skeleton: [image]

So far so good!
But now I am stuck on how to feed the x axis to the loop to generate the subcategories. I am trying something like the code below, but nothing works.

#for i, ax in enumerate(axes.flatten()):
 #   ax.plot(??,cityValues, marker='o',markersize=25, lw=0, color="green") # The top-left axes

As I already have the code for the look and feel of the chart, annotations, etc., I would love to be able to use the plt.subplots method, and I would prefer not to use seaborn if possible.

Any ideas on how to get this working?

Thanks in advance!

Asked By: ruthpozuelo


Answers:

A working example is below. I used seaborn to plot the bars, but the idea is the same: loop through the facets and increment a counter. Start the counter at -1 so that its first value is 0, and use it as the index into the array of axes.

import matplotlib.pyplot as plt
import seaborn as sns

# df and itemCount come from the question's code, so the column names are lower-case
df['cost'] = df['cost'].astype(float)
count = -1
variables = df['item'].unique()

fig, axs = plt.subplots(1, itemCount, figsize=(25,70), sharex=False, sharey=False)

for var in variables:
    count += 1
    # plot only the rows for this item on its own axes
    sns.barplot(ax=axs[count], data=df[df['item'] == var], x='cost', y='city')
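
Since the question asks to stay with plt.subplots and avoid seaborn, the same per-item loop can also be sketched with plain matplotlib. This is only a sketch under assumptions: the small frame below is made up (its values are illustrative, not taken from the real file), the names `item_count` and `panel_titles` are mine, and `squeeze=False` is added so the loop still works if the filter leaves a single item.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Made-up rows using the question's column names; the values are illustrative only
df = pd.DataFrame({
    "city": ["Zurich", "Paris", "Lima", "Zurich", "Paris", "Lima"],
    "category": ["Party night"] * 6,
    "item": ["Dinner"] * 3 + ["Club entry"] * 3,
    "cost": [90.0, 70.0, 30.0, 25.0, 20.0, 10.0],
})
df = df[df["category"] == "Party night"]  # the filter that would come from Power BI

item_count = df["item"].nunique()
# squeeze=False keeps `axes` a 2-D array even when only one item survives the filter
fig, axes = plt.subplots(nrows=1, ncols=item_count, figsize=(6 * item_count, 4),
                         sharey=True, squeeze=False)

# groupby yields (item name, rows for that item); zip pairs each group with one Axes
for ax, (item, group) in zip(axes.ravel(), df.groupby("item")):
    ax.plot(group["cost"], group["city"], marker="o", markersize=15, lw=0, color="green")
    ax.set_title(item)

panel_titles = [ax.get_title() for ax in axes.ravel()]
```

Here the x values asked about in the question are simply the `cost` column of each group, so the number of panels follows whatever the filter leaves in `item`.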
    
   
Answered By: DGradeci

I used the data provided as the basis for the code. I prepared a list of columns and a list of colors and looped through them. axes.ravel() is more memory efficient than axes.flatten() because it returns a view where possible instead of always making a copy. Either way, the result is a flat array with one Axes object per subplot, so all subplots can be configured in a single loop.
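
The difference between ravel() and flatten() can be checked directly in NumPy. The snippet below is a small illustration, assuming an integer array as a stand-in for the array of Axes objects that plt.subplots returns; the variable names are made up for the example.

```python
import numpy as np

# Stand-in for a 2x3 grid of Axes objects
grid = np.arange(6).reshape(2, 3)

flat_view = grid.ravel()     # returns a view when the memory layout allows it
flat_copy = grid.flatten()   # always returns a new copy

view_shares_memory = np.shares_memory(grid, flat_view)
copy_shares_memory = np.shares_memory(grid, flat_copy)
```

For a handful of subplots the saving is negligible, so either method is fine in practice.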

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

url='https://raw.githubusercontent.com/Curbal-Data-Labs/Matplotlib-Labs/master/2018w48.csv'
dataset = pd.read_csv(url)

# keep one row per (city, item) pair
dataset.drop_duplicates(['city','item'], inplace=True)

# pivot to one row per city and one column per item; the 'All' margins are
# used only to order the rows and columns, and are then dropped
df = (
    dataset.pivot_table(index='city', columns='item', values='cost', aggfunc='sum', margins=True)
    .sort_values('All', ascending=True)
    .drop('All', axis=1)
    .sort_values('All', ascending=False, axis=1)
    .drop('All')
    .reset_index()
)

# the costs use a decimal comma; replace it and convert to float
for c in df.columns[1:]:
    df[c] = df[c].str.replace(',','.').astype(float)

fig, axes = plt.subplots(nrows=1, ncols=5, figsize=(30,10), sharey=True)

colors = ['green','blue','red','black','brown'] 
col_names = ['Dinner','Drinks at Dinner','2 Longdrinks','Club entry','Cinema entry']

for ax, col, c in zip(axes.ravel(), col_names, colors):
    ax.plot(df.loc[:,col], df['city'], marker='o', markersize=25, lw=0, color=c)
    ax.set_title(col)
    # label each marker with its cost
    for cost, city in zip(df[col], df['city']):
        ax.annotate('$'+str(cost), xy=(cost, city), xytext=(cost-4, city), color="white", fontsize=8)
    ax.set_xticks([])
    ax.spines[['top', 'right', 'left', 'bottom']].set_visible(False)
    ax.grid(True, axis='y', linestyle='solid', linewidth=2)
    ax.grid(True, axis='x', linestyle='solid', linewidth=0.2)
    ax.xaxis.tick_top() 
    ax.xaxis.set_label_position('top')
    ax.set_xlim(xmin=0, xmax=160)
    ax.xaxis.set_major_formatter('${x:1.0f}')
    ax.tick_params(labelsize=8, top=False, left=False)

plt.show()
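
The code above hard-codes five columns and five colors, which would not follow a changing filter. One possible way to make it dynamic is to derive both lists from the pivoted frame. This is a sketch, not the answerer's method: the frame below is a made-up stand-in for the pivoted `df` (illustrative values only), and the "C0", "C1", ... strings are matplotlib's default color cycle, chosen here as an assumption in place of the hand-picked colors.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import pandas as pd

# Made-up stand-in for the pivoted frame: 'city' first, then one column per item
df = pd.DataFrame({
    "city": ["Lima", "Paris", "Zurich"],
    "Dinner": [30.0, 70.0, 90.0],
    "Club entry": [10.0, 20.0, 25.0],
    "Cinema entry": [5.0, 12.0, 18.0],
})

col_names = list(df.columns[1:])                        # every item column present
colors = [f"C{i % 10}" for i in range(len(col_names))]  # default color cycle, repeated if needed

# squeeze=False keeps `axes` a 2-D array even for a single column
fig, axes = plt.subplots(nrows=1, ncols=len(col_names), figsize=(6 * len(col_names), 4),
                         sharey=True, squeeze=False)

for ax, col, c in zip(axes.ravel(), col_names, colors):
    ax.plot(df[col], df["city"], marker="o", markersize=15, lw=0, color=c)
    ax.set_title(col)
```

The styling calls from the loop above (annotations, spines, grid) can be kept unchanged inside this loop.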


Answered By: r-beginners