How to use sns.catplot to visualize the change of a dataset

Question:

I have a dataset whose mean and variance change over time, while the y values stay in the same range. I want to visualize this change in the data. Here is a plot with the features I’m looking for:
[image: desired plot]

Unfortunately I don’t know how to create such a beauty. I tried the seaborn.catplot() function, but I can’t find an option to show the mean or to connect the individual data points.

import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

dist_1 = np.random.beta(2.,2.,250)
dist_2 = np.random.beta(3.,4.,250)
data = {"dist 1": dist_1, "dist 2": dist_2}
df = pd.DataFrame(data=data)
sns.catplot(data=df, s=3.)
plt.show()

The result would be something like this:

[image: resulting catplot]

Any ideas to add the missing features?

Asked By: Lufr1ed


Answers:

To get a categorical plot, you need a data frame that contains a categorical column. I slightly modified your code, renaming the columns to dist1 and dist2 so that they share the prefix dist followed by a numeric suffix:
import numpy as np
import seaborn as sns
import pandas as pd
import matplotlib.pyplot as plt

dist_1 = np.random.beta(2.,2.,250)
dist_2 = np.random.beta(3.,4.,250)
data = {"dist1": dist_1, "dist2": dist_2}
df = pd.DataFrame(data=data)
df['id'] = df.index

I added a column for the index of each row so that I can use it to convert this wide dataframe into a long one.

long_df = pd.wide_to_long(df, 'dist', 'id', 'cat').reset_index()
long_df.head()
#   id  cat      dist
#0   0    1  0.558567
#1   1    1  0.265868
#2   2    1  0.076470
#3   3    1  0.891855
#4   4    1  0.215208
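
To make the reshaping step easier to follow, here is a minimal sketch on a made-up three-row frame. The values and the names `wide_small` and `long_small` are only illustrative and are not part of the data above.

```python
import pandas as pd

# Tiny wide frame: one row per observation, one column per distribution.
wide_small = pd.DataFrame({"dist1": [0.1, 0.2, 0.3],
                           "dist2": [0.4, 0.5, 0.6]})
wide_small["id"] = wide_small.index

# stubnames="dist" matches the shared column prefix; the numeric suffix
# (1 or 2) ends up in the new "cat" column, and "id" identifies the row.
long_small = pd.wide_to_long(
    wide_small, stubnames="dist", i="id", j="cat"
).reset_index()

print(long_small)
```

Each original row now appears twice, once per category, which is the shape that the `x`, `y` and `hue` arguments of seaborn expect.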

Now I want to use the cat column on the x-axis of the catplot, dist on the y-axis and create a point plot for each row in the original dataframe.

sns.catplot(x="cat", y="dist", hue='id', kind="point", data=long_df, legend=False)
plt.show()

This gives you the following plot:

[image: Catplot]

Bonus: You can also do the same with a line plot and add markers for the mean.

mean_df = long_df.groupby('cat')['dist'].mean().reset_index()
mean_df
#   cat      dist
#0    1  0.535771
#1    2  0.399939
fig = plt.figure()
ax1 = fig.add_subplot(111)
sns.lineplot(x='cat', y='dist', hue='id', data=long_df, legend=False, ax=ax1)
ax1.plot(mean_df['cat'], mean_df['dist'], 'x')
plt.show()

Which gives you:

[image: Line plot with mean added]

The ‘x’ marker at each category represents the mean. Hope that helps!
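
If drawing one hue level per row turns out to be slow with 250 rows, a plain matplotlib version is another option. The following is only a sketch under my own assumptions: the helper name `paired_plot`, the colors, and the fixed seed are all made up for illustration, and it draws the categories at x positions 0 and 1 rather than using seaborn.

```python
import numpy as np
import matplotlib.pyplot as plt


def paired_plot(a, b, labels=("dist 1", "dist 2")):
    """Draw paired samples at x=0 and x=1, connect each pair, mark the means."""
    a, b = np.asarray(a), np.asarray(b)
    fig, ax = plt.subplots()
    # y has shape (2, n): matplotlib draws one line per column,
    # so every pair (a[i], b[i]) gets its own connecting line.
    ax.plot([0, 1], np.vstack([a, b]), color="gray", alpha=0.3,
            linewidth=0.5, marker="o", markersize=2)
    # The means are drawn last with a higher zorder so they stay visible.
    ax.plot([0, 1], [a.mean(), b.mean()], color="red", marker="x",
            markersize=10, zorder=3)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(labels)
    ax.set_xlim(-0.5, 1.5)
    return fig, ax


# Same distributions as in the question, with a seed for reproducibility.
rng = np.random.default_rng(0)
fig, ax = paired_plot(rng.beta(2.0, 2.0, 250), rng.beta(3.0, 4.0, 250))
```

Call plt.show() afterwards to display the figure, or fig.savefig(...) to write it to a file.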

Answered By: Ic3fr0g